Merge remote-tracking branch 'cntk/master' into merge

This commit is contained in:
Scott Cyphers 2015-09-17 09:32:17 -04:00
Родитель 1036728875 7f88e5b771
Коммит cf29bf0f38
213 изменённых файлов: 34130 добавлений и 38404 удалений

Двоичный файл не отображается.

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -0,0 +1,29 @@
// BrainScriptEvaluator.h -- execute what's given in a config file
#pragma once
#include "Basics.h"
#include "ScriptableObjects.h"
#include "BrainScriptParser.h"
#include <memory> // for shared_ptr
namespace Microsoft { namespace MSR { namespace BS {
using namespace std;
using namespace Microsoft::MSR::ScriptableObjects;
// -----------------------------------------------------------------------
// functions exposed by this module
// TODO: This is the only thing that should stay in an actual BrainScriptEvaluator.h.
// -----------------------------------------------------------------------
// understand and execute from the syntactic expression tree
ConfigValuePtr Evaluate(ExpressionPtr); // evaluate the expression tree
void Do(ExpressionPtr e); // evaluate e.do
shared_ptr<Object> EvaluateField(ExpressionPtr e, const wstring & id); // for experimental CNTK integration
// some simple tests
void SomeTests();
}}} // end namespaces

Просмотреть файл

@ -0,0 +1,797 @@
// ConfigParser.cpp -- config parser (syntactic only, that is, source -> Expression tree)
#define _CRT_SECURE_NO_WARNINGS // "secure" CRT not available on all platforms --add this at the top of all CPP files that give "function or variable may be unsafe" warnings
#include "BrainScriptParser.h"
#include <cstdio>
#include <cstdlib>
#include <cctype>
#include <string>
#include <vector>
#include <deque>
#include <set>
#include <stdexcept>
#include <algorithm>
#ifndef let
#define let const auto
#endif
namespace Microsoft { namespace MSR { namespace BS {
using namespace std;
using namespace msra::strfun;
using namespace Microsoft::MSR::CNTK;
// ---------------------------------------------------------------------------
// source files and text references (location) into them
// ---------------------------------------------------------------------------
// SourceFile constructors
SourceFile::SourceFile(wstring location, wstring text) : path(location), lines(split(text, L"\r\n")) { } // from string, e.g. command line
SourceFile::SourceFile(wstring path) : path(path) // from file
{
File(path, fileOptionsRead).GetLines(lines);
}
bool TextLocation::IsValid() const { return sourceFileAsIndex != SIZE_MAX; }
// register a new source file and return a TextPosition that points to its start
/*static*/ TextLocation TextLocation::NewSourceFile(SourceFile && sourceFile)
{
TextLocation loc;
loc.lineNo = 0;
loc.charPos = 0;
loc.sourceFileAsIndex = sourceFileMap.size(); // index under which we store the source file
sourceFileMap.push_back(move(sourceFile)); // take ownership of the source file and give it a numeric index
return loc;
}
// helper for pretty-printing errors: Show source-code line with ...^ under it to mark up the point of error
struct Issue
{
TextLocation location; // using lineno and source file; char position only for printing the overall error loc
wstring markup; // string with markup symbols at char positions and dots inbetween
void AddMarkup(wchar_t symbol, size_t charPos)
{
if (charPos >= markup.size())
markup.resize(charPos+1, L' '); // fill with '.' up to desired position if the string is not that long yet
if (markup[charPos] == L' ') // don't overwrite
markup[charPos] = symbol;
}
Issue(TextLocation location) : location(location) { }
};
// trace
/*static*/ void TextLocation::Trace(TextLocation location, const wchar_t * traceKind, const wchar_t * op, const wchar_t * exprPath)
{
fprintf(stderr, "%ls: %ls (path %ls)\n", traceKind, op, exprPath);
const auto & lines = location.GetSourceFile().lines;
const auto line = (location.lineNo == lines.size()) ? L"(end)" : lines[location.lineNo].c_str();
Issue issue(location);
issue.AddMarkup(L'^', location.charPos);
fprintf(stderr, " %ls\n %ls\n", line, issue.markup.c_str());
}
// report an error
// The source line is shown, and the position is marked as '^'.
// Because it is often hard to recognize an issue only from the point where it occurred, we also report the history in compact visual form.
// Since often multiple contexts are on the same source line, we only print each source line once in a consecutive row of contexts.
/*static*/ void TextLocation::PrintIssue(const vector<TextLocation> & locations, const wchar_t * errorKind, const wchar_t * kind, const wchar_t * what)
{
vector<Issue> issues; // tracing the error backwards
size_t symbolIndex = 0;
for (size_t n = 0; n < locations.size(); n++)
{
let & location = locations[n];
if (!location.IsValid()) // means thrower has no location, go up one context
continue;
// build the array
if (symbolIndex == 0 || location.lineNo != issues.back().location.lineNo || location.sourceFileAsIndex != issues.back().location.sourceFileAsIndex)
{
if (issues.size() == 10)
break;
else
issues.push_back(location);
}
// get the symbol to indicate how many steps back, in this sequence: ^ 0..9 a..z A..Z (we don't go further than this)
wchar_t symbol;
if (symbolIndex == 0) symbol = '^';
else if (symbolIndex < 1 + 10) symbol = '0' + (wchar_t)symbolIndex - 1;
else if (symbolIndex < 1 + 10 + 26) symbol = 'a' + (wchar_t)symbolIndex - (1 + 10);
else if (symbolIndex < 1 + 10 + 26 + 26) symbol = 'A' + (wchar_t)symbolIndex - (1 + 10 + 26);
else break;
symbolIndex++;
// insert the markup
issues.back().AddMarkup(symbol, location.charPos);
}
// print it backwards
if (!locations.empty()) // (be resilient to some throwers not having a TextrLocation; to be avoided)
{
let & firstLoc = issues.front().location;
fprintf(stderr, "\n%ls while %ls line %d char %d of %ls\n", errorKind, kind, (int)firstLoc.lineNo + 1/*report 1-based*/, (int)firstLoc.charPos + 1, firstLoc.GetSourceFile().path.c_str());
fprintf(stderr, "see location marked ^ and parent contexts marked 0..9, a..z, A..Z:\n\n");
for (auto i = issues.rbegin(); i != issues.rend(); i++)
{
let & issue = *i;
auto & where = issue.location;
const auto & lines = where.GetSourceFile().lines;
const auto line = (where.lineNo == lines.size()) ? L"(end)" : lines[where.lineNo].c_str();
fprintf(stderr, " %ls\n %ls\n", line, issue.markup.c_str());
}
}
fprintf(stderr, "%ls: %ls\n", errorKind, what);
fflush(stderr);
}
/*static*/ vector<SourceFile> TextLocation::sourceFileMap;
// ---------------------------------------------------------------------------
// reader -- reads source code, including loading from disk
// ---------------------------------------------------------------------------
class CodeSource
{
vector<TextLocation> locationStack; // parent locations in case of included files
TextLocation cursor; // current location
const wchar_t * currentLine; // cache of cursor.GetSourceFile().lines[cursor.lineNo]
// update currentLine from cursor
void CacheCurrentLine()
{
let & lines = cursor.GetSourceFile().lines;
if (cursor.lineNo == lines.size())
currentLine = nullptr;
else
currentLine = lines[cursor.lineNo].c_str();
}
protected:
// set a source file; only do that from constructor or inside PushSourceFile()
void SetSourceFile(SourceFile && sourceFile)
{
cursor = TextLocation::NewSourceFile(move(sourceFile)); // save source file and set the cursor to its start
CacheCurrentLine(); // re-cache current line
}
public:
class CodeSourceError : public ConfigError
{
public:
CodeSourceError(const wstring & msg, TextLocation where) : ConfigError(msg, where) { }
/*ConfigError::*/ const wchar_t * kind() const { return L"reading source"; }
};
__declspec_noreturn static void Fail(wstring msg, TextLocation where) { throw CodeSourceError(msg, where); }
// enter a source file, at start or as a result of an include statement
void PushSourceFile(SourceFile && sourceFile)
{
locationStack.push_back(cursor);
SetSourceFile(move(sourceFile));
}
// are we inside an include file?
bool IsInInclude() { return locationStack.size() > 0; }
// done with a source file. Only call this for nested files; the outermost one must not be popped.
void PopSourceFile()
{
if (!IsInInclude())
LogicError("PopSourceFile: location stack empty");
cursor = locationStack.back(); // restore cursor we came from
CacheCurrentLine(); // re-cache current line
locationStack.pop_back();
}
// get current cursor; this is remembered for each token, and also used when throwing errors
TextLocation GetCursor() const { return cursor; }
// get character at current position.
// Special cases:
// - end of line is returned as '\n'
// - end of file is returned as 0
wchar_t GotChar() const
{
if (!currentLine) return 0; // end of file
else if (!currentLine[cursor.charPos]) return '\n'; // end of line
else return currentLine[cursor.charPos];
}
// we chan also return the address of the current character, e.g. for passing it to a C stdlib funcion such as wcstod()
const wchar_t * GotCharPtr() const { return currentLine + cursor.charPos; }
// advance cursor by #chars (but across line boundaries)
void ConsumeChars(size_t chars)
{
let ch = GotChar();
if (!ch) LogicError("Consume: cannot run beyond end of source file");
if (ch == '\n' && chars > 0)
{
if (chars != 1) LogicError("Consume: cannot run beyond end of line");
cursor.lineNo++;
CacheCurrentLine(); // line no has changed: re-cache the line ptr
cursor.charPos = 0;
}
else
cursor.charPos += chars;
}
// get the next character
wchar_t GetChar()
{
ConsumeChars(1);
return GotChar();
}
};
// ---------------------------------------------------------------------------
// lexer -- iterates over the source code and returns token by token
// ---------------------------------------------------------------------------
class Lexer : public CodeSource
{
set<wstring> keywords;
set<wstring> punctuations;
public:
Lexer() : CodeSource(), currentToken(TextLocation())
{
keywords = set<wstring>
{
L"include",
L"new", L"true", L"false",
L"if", L"then", L"else",
L"array",
};
punctuations = set<wstring>
{
L"=", L";", L",", L"\n",
L"[", L"]", L"(", L")",
L"+", L"-", L"*", L"/", L"**", L".*", L"%", L"||", L"&&", L"^",
L"!",
L"==", L"!=", L"<", L"<=", L">", L">=",
L":", L"=>",
L"..", L".",
L"//", L"#", L"/*"
};
}
enum TokenKind
{
invalid, punctuation, numberliteral, stringliteral, booleanliter, identifier, keyword, eof // TODO: what are true and false? Literals or identifiers?
};
struct Token
{
wstring symbol; // identifier, keyword, punctuation, or string literal
double number; // number
TokenKind kind;
TextLocation beginLocation; // text loc of first character of this token
bool isLineInitial; // this token is the first on the line (ignoring comments)
Token(TextLocation loc) : beginLocation(loc), kind(invalid), number(0.0), isLineInitial(false) { }
// diagnostic helper
static wstring TokenKindToString(TokenKind kind)
{
switch (kind)
{
case invalid: return L"invalid";
case punctuation: return L"punctuation";
case numberliteral: return L"numberliteral";
case stringliteral: return L"stringliteral";
case identifier: return L"identifier";
case keyword: return L"keyword";
case eof: return L"eof";
default: return L"(unknown?)";
}
}
wstring ToString() const // string to show the content of token for debugging
{
let kindStr = TokenKindToString(kind);
switch (kind)
{
case numberliteral: return kindStr + wstrprintf(L" %f", number);
case stringliteral: return kindStr + L" '" + symbol + L"'";
case identifier: case keyword: case punctuation: return kindStr + L" " + symbol;
default: return kindStr;
}
}
};
class LexerError : public ConfigError
{
public:
LexerError(const wstring & msg, TextLocation where) : ConfigError(msg, where) { }
/*ConfigError::*/ const wchar_t * kind() const { return L"tokenizing"; }
};
private:
__declspec_noreturn static void Fail(wstring msg, Token where) { throw LexerError(msg, where.beginLocation); }
Token currentToken;
// consume input characters to form a next token
// - this function mutates the cursor, but does not set currentToken
// - white space and comments are skipped
// - including files is handled here
// - the cursor is left on the first character that does not belong to the token
// TODO: need to know whether we want to see '\n' or not
Token NextToken()
{
auto ch = GotChar();
// skip white space
// We remember whether we crossed a line end. Dictionary assignments end at newlines if syntactically acceptable.
bool crossedLineEnd = (GetCursor().lineNo == 0 && GetCursor().charPos == 0);
while (iswblank(ch) || ch == '\n' || ch == '\r')
{
crossedLineEnd |= (ch == '\n' || ch == '\r');
ch = GetChar();
}
Token t(GetCursor());
t.isLineInitial = crossedLineEnd;
// handle end of (include) file
if (ch == 0)
{
if (IsInInclude())
{
PopSourceFile();
t = NextToken(); // tail call--the current 't' gets dropped/ignored
t.isLineInitial = true; // eof is a line end
return t;
}
// really end of all source code: we are done. If calling this function multiple times, we will keep returning this.
t.kind = eof;
}
else if (iswdigit(ch) || (ch == L'.' && iswdigit(GotCharPtr()[1]))) // --- number
{
let beginPtr = GotCharPtr();
wchar_t * endPtr = nullptr;
t.number = wcstod(beginPtr, &endPtr); // BUGBUG: this seems to honor locale settings. We need one that doesn't. With this, CNTK won't parse right in Germany.
if (endPtr == beginPtr) Fail(L"parsing number", t); // should not really happen!
t.kind = numberliteral;
if (endPtr[0] == L'.' && endPtr[-1] == L'.') // prevent 1..2 from begin tokenized 1. .2
endPtr--;
ConsumeChars(endPtr - beginPtr);
}
else if (iswalpha(ch) || ch == L'_') // --- identifier or keyword
{
while (iswalpha(ch) || ch == L'_' || iswdigit(ch)) // inside we also allow digits
{
t.symbol.push_back(ch);
ch = GetChar();
}
// check against keyword list
if (keywords.find(t.symbol) != keywords.end()) t.kind = keyword;
else t.kind = identifier;
// special case: include "path"
if (t.symbol == L"include")
{
let nameTok = NextToken(); // must be followed by a string literal
if (nameTok.kind != stringliteral) Fail(L"'include' must be followed by a quoted string", nameTok);
let path = nameTok.symbol; // TODO: some massaging of the path
PushSourceFile(SourceFile(path)); // current cursor is right after the pathname; that's where we will pick up later
return NextToken();
}
}
else if (ch == L'"' || ch == 0x27) // --- string literal
{
t.kind = stringliteral;
let q = ch; // remember quote character
ch = GetChar(); // consume the quote character
while (ch != 0 && ch != q) // note: our strings do not have any escape characters to consider
{
t.symbol.append(1, ch);
ch = GetChar();
}
if (ch == 0) // runaway string
Fail(L"string without closing quotation mark", t);
GetChar(); // consume the closing quote
}
else // --- punctuation
{
t.kind = punctuation;
t.symbol = ch;
t.symbol.append(1, GetChar()); // first try two-char punctuation
if (punctuations.find(t.symbol) != punctuations.end())
GetChar(); // it is a two-char one: need to consume the second one of them
else // try single-char one
{
t.symbol.pop_back(); // drop the last one & try again
if (punctuations.find(t.symbol) == punctuations.end()) // unknown
Fail(L"unexpected character: " + t.symbol, t);
}
// special case: comments
if (t.symbol == L"#" || t.symbol == L"//")
{
ConsumeChars(wcslen(GotCharPtr()));
return NextToken();
}
else if (t.symbol == L"/*")
{
ch = GotChar();
while (ch != 0 && !(ch == L'*' && GetChar() == L'/')) // note: this test leverages short-circuit evaluation semantics of C
ch = GetChar();
if (ch == 0)
Fail(L"comment without closing */", t);
GetChar(); // consume the final '/'
return NextToken(); // and return the next token
}
}
return t;
}
public:
const Token & GotToken() { return currentToken; }
void ConsumeToken() { currentToken = NextToken(); }
const Token & GetToken()
{
ConsumeToken();
return GotToken();
}
// some simple test function
void Test()
{
let lexerTest = L"new CNTK [ do = (train:eval) # main\ntrain=/*test * */if eval include 'c:/me/test.txt' then 13 else array[1..10](i=>i*i); eval=\"a\"+'b' // line-end\n ] 'a\nb\nc' new";
PushSourceFile(SourceFile(L"(command line)", lexerTest));
while (GotToken().kind != Lexer::TokenKind::eof)
{
let & token = GotToken(); // get first token
fprintf(stderr, "%ls\n", token.ToString().c_str());
ConsumeToken();
}
Fail(L"error test", GetCursor());
}
};
// ---------------------------------------------------------------------------
// parser -- parses configurations
// ---------------------------------------------------------------------------
// diagnostics helper: print the content
void Expression::Dump(int indent) const
{
fprintf(stderr, "%*s", indent, "");
if (op == L"s") fprintf(stderr, "'%ls' ", s.c_str());
else if (op == L"d") fprintf(stderr, "%.f ", d);
else if (op == L"b") fprintf(stderr, "%s ", b ? "true" : "false");
else if (op == L"id") fprintf(stderr, "%ls ", id.c_str());
else if (op == L"new" || op == L"array" || op == L".") fprintf(stderr, "%ls %ls ", op.c_str(), id.c_str());
else fprintf(stderr, "%ls ", op.c_str());
if (!args.empty())
{
fprintf(stderr, "\n");
for (const auto & arg : args)
arg->Dump(indent + 2);
}
if (!namedArgs.empty())
{
fprintf(stderr, "\n");
for (const auto & arg : namedArgs)
{
fprintf(stderr, "%*s%ls =\n", indent + 2, "", arg.first.c_str());
arg.second.second->Dump(indent + 4);
}
}
fprintf(stderr, "\n");
}
class Parser : public Lexer
{
// errors
class ParseError : public ConfigError
{
public:
ParseError(const wstring & msg, TextLocation where) : ConfigError(msg, where) { }
/*ConfigError::*/ const wchar_t * kind() const { return L"parsing"; }
};
__declspec_noreturn static void Fail(const wstring & msg, Token where) { throw ParseError(msg, where.beginLocation); }
//void Expected(const wstring & what) { Fail(strprintf("%ls expected", what.c_str()), GotToken().beginLocation); } // I don't know why this does not work
void Expected(const wstring & what) { Fail(what + L" expected", GotToken().beginLocation); }
// this token must be punctuation 's'; check and get the next
void ConsumePunctuation(const wchar_t * s)
{
let & tok = GotToken();
if (tok.kind != punctuation || tok.symbol != s)
Expected(L"'" + wstring(s) + L"'");
ConsumeToken();
}
// this token must be keyword 's'; check and get the next
void ConsumeKeyword(const wchar_t * s)
{
let & tok = GotToken();
if (tok.kind != keyword || tok.symbol != s)
Expected(L"'" + wstring(s) + L"'");
ConsumeToken();
}
// this token must be an identifier; check and get the next token. Return the identifier.
wstring ConsumeIdentifier()
{
let & tok = GotToken();
if (tok.kind != identifier)
Expected(L"identifier");
let id = tok.symbol;
ConsumeToken();
return id;
}
map<wstring, int> infixPrecedence; // precedence level of infix operators
public:
Parser(SourceFile && sourceFile) : Lexer()
{
infixPrecedence = map<wstring, int>
{
{ L".", 100 }, { L"[", 100 }, { L"(", 100 }, // also sort-of infix operands...
{ L"*", 10 }, { L"/", 10 }, { L".*", 10 }, { L"**", 10 }, { L"%", 10 },
{ L"+", 9 }, { L"-", 9 },
{ L"==", 8 }, { L"!=", 8 }, { L"<", 8 }, { L"<=", 8 }, { L">", 8 }, { L">=", 8 },
{ L"&&", 7 },
{ L"||", 6 },
{ L":", 5 },
{ L"=>", 0 },
};
SetSourceFile(move(sourceFile));
ConsumeToken(); // get the very first token
}
ExpressionPtr OperandFromTokenSymbol(const Token & tok) // helper to make an Operand expression with op==tok.symbol and then consume it
{
auto operand = make_shared<Expression>(tok.beginLocation, tok.symbol);
ConsumeToken();
return operand;
}
ExpressionPtr ParseOperand(bool stopAtNewline)
{
let & tok = GotToken();
ExpressionPtr operand;
if (tok.kind == numberliteral) // === numeral literal
{
operand = make_shared<Expression>(tok.beginLocation, L"d", tok.number, wstring(), false);
ConsumeToken();
}
else if (tok.kind == stringliteral) // === string literal
{
operand = make_shared<Expression>(tok.beginLocation, L"s", 0.0, tok.symbol, false);
ConsumeToken();
}
else if (tok.symbol == L"true" || tok.symbol == L"false") // === boolean literal
{
operand = make_shared<Expression>(tok.beginLocation, L"b", 0.0, wstring(), (tok.symbol == L"true"));
ConsumeToken();
}
else if (tok.kind == identifier) // === dict member (unqualified)
{
operand = make_shared<Expression>(tok.beginLocation, L"id");
operand->id = ConsumeIdentifier();
}
else if (tok.symbol == L"+" || tok.symbol == L"-" // === unary operators
|| tok.symbol == L"!")
{
operand = make_shared<Expression>(tok.beginLocation, tok.symbol + L"("); // encoded as +( -( !(
ConsumeToken();
operand->args.push_back(ParseExpression(100, stopAtNewline));
}
else if (tok.symbol == L"new") // === new class instance
{
operand = OperandFromTokenSymbol(tok);
operand->id = ConsumeIdentifier();
operand->args.push_back(ParseOperand(stopAtNewline));
}
else if (tok.symbol == L"if") // === conditional expression
{
operand = OperandFromTokenSymbol(tok);
operand->args.push_back(ParseExpression(0, false)); // [0] condition
ConsumeKeyword(L"then");
operand->args.push_back(ParseExpression(0, false)); // [1] then expression
ConsumeKeyword(L"else");
operand->args.push_back(ParseExpression(0, false)); // [2] else expression
}
else if (tok.symbol == L"(") // === nested parentheses
{
ConsumeToken();
operand = ParseExpression(0, false/*go across newlines*/); // just return the content of the parens (they do not become part of the expression tree)
ConsumePunctuation(L")");
}
else if (tok.symbol == L"[") // === dictionary constructor
{
operand = make_shared<Expression>(tok.beginLocation, L"[]");
ConsumeToken();
operand->namedArgs = ParseRecordMembers();
ConsumePunctuation(L"]");
}
else if (tok.symbol == L"array") // === array constructor
{
operand = OperandFromTokenSymbol(tok);
ConsumePunctuation(L"[");
operand->args.push_back(ParseExpression(0, false)); // [0] first index
ConsumePunctuation(L"..");
operand->args.push_back(ParseExpression(0, false)); // [1] last index
ConsumePunctuation(L"]");
ConsumePunctuation(L"(");
operand->args.push_back(ParseExpression(0, false)); // [2] one-argument lambda to initialize
ConsumePunctuation(L")");
}
else
Expected(L"operand");
return operand; // not using returns above to avoid "not all control paths return a value"
}
ExpressionPtr ParseExpression(int requiredPrecedence, bool stopAtNewline)
{
auto left = ParseOperand(stopAtNewline); // get first operand
for (;;)
{
let & opTok = GotToken();
if (stopAtNewline && opTok.isLineInitial)
break;
let opIter = infixPrecedence.find(opTok.symbol);
if (opIter == infixPrecedence.end()) // not an infix operator: we are done here, 'left' is our expression
break;
let opPrecedence = opIter->second;
if (opPrecedence < requiredPrecedence) // operator below required precedence level: does not belong to this sub-expression
break;
let op = opTok.symbol;
auto operation = make_shared<Expression>(opTok.beginLocation, op, left); // [0] is left operand; we will add [1] except for macro application
// deal with special cases first
// We treat member lookup (.), macro application (a()), and indexing (a[i]) together with the true infix operators.
if (op == L".") // === reference of a dictionary item
{
ConsumeToken();
operation->location = GotToken().beginLocation; // location of the identifier after the .
operation->id = ConsumeIdentifier();
}
else if (op == L"=>")
{
if (left->op != L"id") // currently only allow for a single argument
Expected(L"identifier");
ConsumeToken();
let macroArgs = make_shared<Expression>(left->location, L"()", left); // wrap identifier in a '()' macro-args expression
// TODO: test parsing of i => j => i*j
let body = ParseExpression(opPrecedence, stopAtNewline); // pass same precedence; this makes '=>' right-associative e.g.i=>j=>i*j
operation->args[0] = macroArgs; // [0]: parameter list
operation->args.push_back(body); // [1]: right operand
}
else if (op == L"(") // === macro application
{
// op = "(" means 'apply'
// args[0] = lambda expression (lambda: op="=>", args[0] = param list, args[1] = expression with unbound vars)
// args[1] = arguments (arguments: op="(), args=vector of expressions, one per arg; and namedArgs)
operation->args.push_back(ParseMacroArgs(false)); // [1]: all arguments
}
else if (op == L"[") // === array index
{
ConsumeToken();
operation->args.push_back(ParseExpression(0, false)); // [1]: index
ConsumePunctuation(L"]");
}
else // === regular infix operator
{
ConsumeToken();
let right = ParseExpression(opPrecedence + 1, stopAtNewline); // get right operand, or entire multi-operand expression with higher precedence
operation->args.push_back(right); // [1]: right operand
}
left = operation;
}
return left;
}
// a macro-args expression lists position-dependent and optional parameters
// This is used both for defining macros (LHS) and using macros (RHS).
// Result:
// op = "()"
// args = vector of arguments (which are given comma-separated)
// In case of macro definition, all arguments must be of type "id". Pass 'defining' to check for that.
// namedArgs = dictionary of optional args
// In case of macro definition, dictionary values are default values that are used if the argument is not given
ExpressionPtr ParseMacroArgs(bool defining)
{
ConsumePunctuation(L"(");
auto macroArgs = make_shared<Expression>(GotToken().beginLocation, L"()");
for (;;)
{
let expr = ParseExpression(0, false); // this could be an optional arg (var = val)
if (defining && expr->op != L"id") // when defining we only allow a single identifier
Fail(L"argument identifier expected", expr->location);
if (expr->op == L"id" && GotToken().symbol == L"=")
{
let id = expr->id; // 'expr' gets resolved (to 'id') and forgotten
ConsumeToken();
let defValueExpr = ParseExpression(0, false); // default value
let res = macroArgs->namedArgs.insert(make_pair(id, make_pair(expr->location, defValueExpr)));
if (!res.second)
Fail(L"duplicate optional parameter '" + id + L"'", expr->location);
}
else
macroArgs->args.push_back(expr); // [0..]: position args
if (GotToken().symbol != L",")
break;
ConsumeToken();
}
ConsumePunctuation(L")");
return macroArgs;
}
map<wstring, pair<TextLocation,ExpressionPtr>> ParseRecordMembers()
{
// A dictionary is a map
// member identifier -> expression
// Macro declarations are translated into lambdas, e.g.
// F(A,B) = expr(A,B)
// gets represented in the dictionary as
// F = (A,B) => expr(A,B)
// where a lambda expression has this structure:
// op="=>"
// args[0] = parameter list (op="()", with args (all of op="id") and namedArgs)
// args[1] = expression with unbound arguments
// An array constructor of the form
// V[i:from..to] = expression of i
// gets mapped to the explicit array operator
// V = array[from..to] (i => expression of i)
map<wstring, pair<TextLocation,ExpressionPtr>> members;
auto idTok = GotToken();
while (idTok.kind == identifier)
{
let location = idTok.beginLocation; // for error message
let id = ConsumeIdentifier(); // the member's name
// optional array constructor
ExpressionPtr arrayIndexExpr, fromExpr, toExpr;
if (GotToken().symbol == L"[")
{
// X[i:from..to]
ConsumeToken();
arrayIndexExpr = ParseOperand(false); // 'i' name of index variable
if (arrayIndexExpr->op != L"id")
Expected(L"identifier");
ConsumePunctuation(L":");
fromExpr = ParseExpression(0, false); // 'from' start index
ConsumePunctuation(L"..");
toExpr = ParseExpression(0, false); // 'to' end index
ConsumePunctuation(L"]");
}
// optional macro args
let parameters = (GotToken().symbol == L"(") ? ParseMacroArgs(true/*defining*/) : ExpressionPtr(); // optionally, macro arguments
ConsumePunctuation(L"=");
auto rhs = ParseExpression(0, true/*can end at newline*/); // and the right-hand side
// if macro then rewrite it as an assignment of a lambda expression
if (parameters)
rhs = make_shared<Expression>(parameters->location, L"=>", parameters, rhs);
// if array then rewrite it as an assignment of a array-constructor expression
if (arrayIndexExpr)
{
// create a lambda expression over the index variable
let macroArgs = make_shared<Expression>(arrayIndexExpr->location, L"()", arrayIndexExpr); // wrap identifier in a '()' macro-args expression
let initLambdaExpr = make_shared<Expression>(arrayIndexExpr->location, L"=>", macroArgs, rhs); // [0] is id, [1] is body
rhs = make_shared<Expression>(location, L"array");
rhs->args.push_back(fromExpr); // [0] first index
rhs->args.push_back(toExpr); // [1] last index
rhs->args.push_back(initLambdaExpr); // [2] one-argument lambda to initialize
}
// insert
let res = members.insert(make_pair(id, make_pair(location, rhs)));
if (!res.second)
Fail(L"duplicate member definition '" + id + L"'", location);
// advance
idTok = GotToken();
if (idTok.symbol == L";")
idTok = GetToken();
}
return members;
}
// top-level parse function parses dictonary members
ExpressionPtr Parse()
{
let topMembers = ParseRecordMembers();
if (GotToken().kind != eof)
Fail(L"junk at end of source", GetCursor());
ExpressionPtr topDict = make_shared<Expression>(GetCursor(), L"[]");
topDict->namedArgs = topMembers;
return topDict;
}
// simple test function for use during development
static void Test()
{
let parserTest = L"a=1\na1_=13;b=2 // cmt\ndo = (print\n:train:eval) ; x = array[1..13] (i=>1+i*print.message==13*42) ; print = new PrintAction [ message = 'Hello World' ]";
ParseConfigString(parserTest)->Dump();
}
};
// globally exported functions to execute the parser
static ExpressionPtr Parse(SourceFile && sourceFile) { return Parser(move(sourceFile)).Parse(); }
ExpressionPtr ParseConfigString(wstring text) { return Parse(SourceFile(L"(command line)", text)); }
ExpressionPtr ParseConfigFile(wstring path) { return Parse(SourceFile(path)); }
}}} // namespaces

Просмотреть файл

@ -0,0 +1,103 @@
// ConfigParser.h -- config parser (syntactic only, that is, source -> Expression tree)
#pragma once
#include "Basics.h"
#include "ScriptableObjects.h"
#include "File.h"
#include <string>
#include <vector>
#include <map>
#include <memory>
namespace Microsoft { namespace MSR { namespace BS {
using namespace std;
// ---------------------------------------------------------------------------
// TextLocation -- holds a pointer into a source file
// ---------------------------------------------------------------------------
struct SourceFile // content of one source file (only in this header because TextLocation's private member uses it)
{
/*const*/ wstring path; // where it came from
/*const*/ vector<wstring> lines; // source code lines
SourceFile(wstring location, wstring text); // from string, e.g. command line
SourceFile(wstring path); // from file
};
struct TextLocation // position in the text. Lightweight value struct that we can copy around, even into dictionaries etc., for error messages
{
// source-code locations are given by line number, character position, and the source file
size_t lineNo, charPos; // line number and character index (0-based)
const SourceFile & GetSourceFile() const { return sourceFileMap[sourceFileAsIndex]; } // get the corresponding source-code line
// helpers for pretty-printing errors: Show source-code line with ...^ under it to mark up the point of error
static void PrintIssue(const vector<TextLocation> & locations, const wchar_t * errorKind, const wchar_t * kind, const wchar_t * what);
static void Trace(TextLocation, const wchar_t * traceKind, const wchar_t * op, const wchar_t * exprPath);
// construction
TextLocation() : lineNo(SIZE_MAX), charPos(SIZE_MAX), sourceFileAsIndex(SIZE_MAX) { } // default constructor constructs an unmissably invalid object
bool IsValid() const;
// register a new source file and return a TextPosition that points to its start
static TextLocation NewSourceFile(SourceFile && sourceFile);
private:
size_t sourceFileAsIndex; // source file is remembered in the value struct as an index into the static sourceFileMap[]
// the meaning of the 'sourceFile' index is global, stored in this static map
static vector<SourceFile> sourceFileMap;
};
// ---------------------------------------------------------------------------
// ConfigError -- all errors from processing the config files are reported as ConfigError
// ---------------------------------------------------------------------------
class ConfigError : public Microsoft::MSR::ScriptableObjects::ScriptingError
{
vector<TextLocation> locations; // error location (front()) and evaluation parents (upper)
public:
// Note: All our Error objects use wide strings, which we round-trip through runtime_error as utf8.
ConfigError(const wstring & msg, TextLocation where) : Microsoft::MSR::ScriptableObjects::ScriptingError(msra::strfun::utf8(msg)) { locations.push_back(where); }
// these are used in pretty-printing
TextLocation where() const { return locations.front(); } // where the error happened
virtual const wchar_t * kind() const = 0; // e.g. "warning" or "error"
// pretty-print this as an error message
void /*ScriptingError::*/PrintError() const { TextLocation::PrintIssue(locations, L"error", kind(), msra::strfun::utf16(what()).c_str()); }
void AddLocation(TextLocation where) { locations.push_back(where); }
};
// ---------------------------------------------------------------------------
// Expression -- the entire config is a tree of Expression types
// We don't use polymorphism here because C++ is so verbose...
// ---------------------------------------------------------------------------
struct Expression
{
wstring op; // operation, encoded as a string; 'symbol' for punctuation and keywords, otherwise used in constructors below ...TODO: use constexpr
wstring id; // identifier; op == "id", "new", "array", and "." (if macro then it also has args)
wstring s; // string literal; op == "s"
double d; // numeric literal; op == "d"
bool b; // boolean literal; op == "b"
typedef shared_ptr<struct Expression> ExpressionPtr;
vector<ExpressionPtr> args; // position-dependent expression/function args
map<wstring, pair<TextLocation,ExpressionPtr>> namedArgs; // named expression/function args; also dictionary members (loc is of the identifier)
TextLocation location; // where in the source code (for downstream error reporting)
// constructors
Expression(TextLocation location) : location(location), d(0.0), b(false) { }
Expression(TextLocation location, wstring op) : location(location), d(0.0), b(false), op(op) { }
Expression(TextLocation location, wstring op, double d, wstring s, bool b) : location(location), d(d), s(s), b(b), op(op) { }
Expression(TextLocation location, wstring op, ExpressionPtr arg) : location(location), d(0.0), b(false), op(op) { args.push_back(arg); }
Expression(TextLocation location, wstring op, ExpressionPtr arg1, ExpressionPtr arg2) : location(location), d(0.0), b(false), op(op) { args.push_back(arg1); args.push_back(arg2); }
// diagnostics helper: print the content
void Dump(int indent = 0) const;
};
typedef Expression::ExpressionPtr ExpressionPtr; // circumvent some circular definition problem
// access the parser through one of these two functions
ExpressionPtr ParseConfigString(wstring text);
ExpressionPtr ParseConfigFile(wstring path);
}}} // namespaces

Просмотреть файл

@ -0,0 +1,217 @@
// BrainScriptTest.cpp -- some tests
#define _CRT_SECURE_NO_WARNINGS // "secure" CRT not available on all platforms --add this at the top of all CPP files that give "function or variable may be unsafe" warnings
#include "Basics.h"
#include "BrainScriptEvaluator.h"
#include "BrainScriptParser.h"
#ifndef let
#define let const auto
#endif
namespace Microsoft { namespace MSR { namespace BS {
using namespace std;
using namespace msra::strfun;
// Note: currently this seems to be the master copy; got to check whether the other one was also changed
//extern wstring standardFunctions, computationNodes, commonMacros;
#if 1 // TODO: these may be newer, merge into Experimentalthingy
static wstring standardFunctions =
L"Print(value, format='') = new PrintAction [ what = value /*; how = format*/ ] \n"
L"Fail(msg) = new FailAction [ what = msg ] \n"
L"RequiredParameter(message) = Fail('RequiredParameter: ' + message) \n"
L"Format(value, format) = new StringFunction [ what = 'Format' ; arg = value ; how = format ] \n"
L"Replace(s, from, to) = new StringFunction [ what = 'Replace' ; arg = s ; replacewhat = from ; withwhat = to ] \n"
L"Substr(s, begin, num) = new StringFunction [ what = 'Substr' ; arg = s ; pos = begin ; chars = num ] \n"
L"Chr(c) = new StringFunction [ what = 'Chr' ; arg = c ] \n"
L"Floor(x) = new NumericFunction [ what = 'Floor' ; arg = x ] \n"
L"Length(x) = new NumericFunction [ what = 'Length' ; arg = x ] \n"
L"Ceil(x) = -Floor(-x) \n"
L"Round(x) = Floor(x+0.5) \n"
L"Abs(x) = if x >= 0 then x else -x \n"
L"Sign(x) = if x > 0 then 1 else if x < 0 then -1 else 0 \n"
L"Min(a,b) = if a < b then a else b \n"
L"Max(a,b) = if a > b then a else b \n"
L"Fac(n) = if n > 1 then Fac(n-1) * n else 1 \n"
;
static wstring computationNodes = // BUGBUG: optional args not working yet, some scope problem causing a circular reference
L"Mean(z, tag='') = new ComputationNode [ class = 'MeanNode' ; inputs = z /* ; tag = tag */ ]\n"
L"InvStdDev(z, tag='') = new ComputationNode [ class = 'InvStdDevNode' ; inputs = z /* ; tag = tag */ ]\n"
L"PerDimMeanVarNormalization(feat,mean,invStdDev, tag='') = new ComputationNode [ class = 'PerDimMeanVarNormalizationNode' ; inputs = feat:mean:invStdDev /* ; tag = tag */ ]\n"
L"Parameter(outD, inD, tag='parameter') = new ComputationNode [ class = 'LearnableParameterNode' ; outDim = outD ; inDim = inD /*; tag = tag*/ ]\n"
L"Input(dim,tag='features') = Parameter(dim,1,tag=tag) // TODO: for now \n"
L"RowSlice(firstRow, rows, features, tag='') = new ComputationNode [ class = 'RowSliceNode' ; inputs = features ; first = firstRow ; num = rows /* ; tag = tag */ ]\n"
L"Delay(in, delay, tag='') = new RecurrentComputationNode [ class = 'DelayNode' ; inputs = in ; deltaT = -delay /* ; tag = tag */ ]\n"
L"Sigmoid(z, tag='') = new ComputationNode [ class = 'SigmoidNode' ; inputs = z /* ; tag = tag */ ]\n"
L"Log(z, tag='') = new ComputationNode [ class = 'LogNode' ; inputs = z /* ; tag = tag */ ]\n"
L"CrossEntropyWithSoftmax(labels, outZ, tag='') = new ComputationNode [ class = 'CrossEntropyWithSoftmaxNode' ; inputs = labels:outZ /* ; tag = tag */ ]\n"
L"ErrorPrediction(labels, outZ, tag='') = new ComputationNode [ class = 'ErrorPredictionNode' ; inputs = labels:outZ /* ; tag = tag */ ]\n"
;
static wstring commonMacros = // TODO: rename rows and cols to inDim and outDim or vice versa, whichever it is
L"BFF(in, rows, cols) = [ B = Parameter(rows, 1/*init = fixedvalue, value = 0*/) ; W = Parameter(rows, cols) ; z = W*in+B ] \n"
L"SBFF(in, rows, cols) = [ Eh = Sigmoid(BFF(in, rows, cols).z) ] \n "
L"MeanVarNorm(feat) = PerDimMeanVarNormalization(feat, Mean(feat), InvStdDev(feat)) \n"
L"LogPrior(labels) = Log(Mean(labels)) \n"
;
#endif
void SomeTests()
{
try
{
// collecting all sorts of test cases here
const wchar_t * parserTests[] =
{
L"do = Parameter(13,42) * Input(42) + Parameter(13,1)"
,
L"do = Print(array [1..10] (i=>i*i))"
,
L"do = new PrintAction [ what = 'abc' ]"
,
L"do = Print(new StringFunction [ x = 13 ; y = 42 ; what = 'Format' ; how = '.2' ; arg = x*y ])"
,
L"do = Print(\"new StringFunction [ what = 'Format' ; how = '.2' ; arg = '13 > 42' ]\")"
,
L"do = new PrintAction [ what = if 13 > 42 || 12 > 1 then 'Hello World' + \"!\" else 'Oops?']"
,
L"i2s(i) = new StringFunction [ what = 'Format' ; arg = i ; how = '.2' ] ; do = Print('result=' + i2s((( [ v = (i => i + delta) ].v(5)))+13)) ; delta = 42 "
,
L"do = Print(1+2*3) : Print('hello'+' world')"
,
L"do = Print(Format( (13:(fortytwo:1):100), '')) ; fortytwo=42 "
,
L"do = Print(val) ; val=if !false then 42 else -+-++-13:[a='a';b=42]:+14; arr = array [1..10] (i => 2*i)"
,
L"do = Print(arg) ; N = 5 ; arr = array [1..N] (i => if i < N then arr[i+1]*i else N) ; arg = arr "
,
L"do = Print(val) ; val = [ v = (i => i + offset) ].v(42) ; offset = 13 "
,
// #12: DNN with recursion
L"do = Print(val) \n"
L"val = new NDLComputationNetwork [\n"
L" featDim=40*31 ; labelDim=9000 ; hiddenDim=2048 ; numHiddenLayers = 3 \n"
L" myFeatures = Input(featDim) ; myLabels = Input(labelDim) \n"
L" featNorm = MeanVarNorm(myFeatures) \n"
L" HiddenStack(layer) = if layer > 1 then SBFF(HiddenStack(layer - 1).Eh, hiddenDim, hiddenDim) else SBFF(featNorm, hiddenDim, featDim) \n"
L" outLayer = BFF(HiddenStack(numHiddenLayers).Eh, labelDim, hiddenDim) \n"
L" outZ = outLayer.z \n"
L" CE = CrossEntropyWithSoftmax(myLabels, outZ) \n"
L" Err = ErrorPrediction(myLabels, outZ) \n"
L" logPrior = LogPrior(myLabels) \n"
L" ScaledLogLikelihood = outZ - logPrior \n"
L"]\n"
,
// #13: factorial
L"do = Print(fac(5)) ; fac(i) = if i > 1 then fac(i-1)*i else 1 "
,
// #14: Fibonacci sequence with memoization
L"do = Print(fibs(10)) ; fibs(n) = [ vals = array[1..n] (i => if i < 3 then i-1 else vals[i-1]+vals[i-2]) ].vals[n] "
,
// #15: DNN with array
L"do = Print(val) \n"
L"val = new NDLComputationNetwork [\n"
L" featDim=40*31 ; labelDim=9000 ; hiddenDim=2048 ; numHiddenLayers = 3 \n"
L" myFeatures = Input(featDim, tag='features') ; myLabels = Input(labelDim, tag='labels') \n"
L" featNorm = MeanVarNorm(myFeatures) \n"
L" layers[layer:1..numHiddenLayers] = if layer > 1 then SBFF(layers[layer-1].Eh, hiddenDim, hiddenDim) else SBFF(featNorm, hiddenDim, featDim) \n"
L" outLayer = BFF(layers[numHiddenLayers].Eh, labelDim, hiddenDim) \n"
L" outZ = outLayer.z + Delay(outZ, 1) \n"
L" CE = CrossEntropyWithSoftmax(myLabels, outZ) \n"
L" Err = ErrorPrediction(myLabels, outZ) \n"
L" logPrior = LogPrior(myLabels) \n"
L" ScaledLogLikelihood = outZ - logPrior \n"
L"]\n"
,
// #16: windowed RNN
L"do = Print(val) \n"
L"val = new NDLComputationNetwork [ \n"
L" hiddenDim = 512 \n"
L" numHiddenLayers = 2 \n"
L" T = 3 // total context window \n"
L" \n"
L" // data sources \n"
L" featDim = 40 ; labelDim = 9000 \n"
L" myFeatures = Input(featDim) ; myLabels = Input(labelDim) \n"
L" \n"
L" // split the augmented input vector into individual frame vectors \n"
L" subframes[t:0..T - 1] = RowSlice(t * featDim, featDim, myFeatures) \n"
L" \n"
L" // hidden layers \n"
L" layers[layer:1..numHiddenLayers] = [ // each layer stores a dict that stores its hidden fwd and bwd state vectors \n"
L" // model parameters \n"
L" W_fwd = Parameter(hiddenDim, featDim) // Parameter(outdim, indim) \n"
L" W_bwd = if layer > 1 then Parameter(hiddenDim, hiddenDim) else Fail('no W_bwd') // input-to-hidden \n"
L" H_fwd = Parameter(hiddenDim, hiddenDim) // hidden-to-hidden \n"
L" H_bwd = Parameter(hiddenDim, hiddenDim) \n"
L" b = Parameter(hiddenDim, 1) // bias \n"
L" // shared part of activations (input connections and bias) \n"
L" z_shared[t:0..T-1] = (if layer > 1 \n"
L" then W_fwd * layers[layer - 1].h_fwd[t] + W_bwd * layers[layer - 1].h_bwd[t] \n"
L" else W_fwd * subframes[t] \n"
L" ) + b \n"
L" // recurrent part and non-linearity \n"
L" step(H, h, dt, t) = Sigmoid(if (t + dt >= 0 && t + dt < T) \n"
L" then z_shared[t] + H * h[t + dt] \n"
L" else z_shared[t]) \n"
L" h_fwd[t:0..T-1] = step(H_fwd, h_fwd, -1, t) \n"
L" h_bwd[t:0..T-1] = step(H_bwd, h_bwd, 1, t) \n"
L" ] \n"
L" // output layer --linear only at this point; Softmax is applied later \n"
L" outLayer = [ \n"
L" // model parameters \n"
L" W_fwd = Parameter(labelDim, hiddenDim) \n"
L" W_bwd = Parameter(labelDim, hiddenDim) \n"
L" b = Parameter(labelDim, 1) \n"
L" // output \n"
L" topHiddenLayer = layers[numHiddenLayers] \n"
L" centerT = Floor(T/2) \n"
L" z = W_fwd * topHiddenLayer.h_fwd[centerT] + W_bwd * topHiddenLayer.h_bwd[centerT] + b \n"
L" ] \n"
L" outZ = outLayer.z // we only want this one & don't care about the rest of this dictionary \n"
L" \n"
L" // define criterion nodes \n"
L" CE = CrossEntropyWithSoftmax(myLabels, outZ) \n"
L" Err = ErrorPrediction(myLabels, outZ) \n"
L" \n"
L" // define output node for decoding \n"
L" logPrior = LogPrior(myLabels) \n"
L" ScaledLogLikelihood = outZ - logPrior // before: Minus(CE.BFF.FF.P,logPrior,tag=Output) \n"
L"]\n"
,
L" \n" // this fails because dict is outside val; expression name is not local to it
L"do = Print(val) \n"
L"dict = [ outY = Input(13) ] ; val = new NDLComputationNetwork [ outZ = dict.outY \n"
L"]\n"
,
L"f(x,option='default') = Print(option); do = f(42,option='value')"
,
NULL
};
let first = 0; // 0 for all
bool oneOnly = first > 0;
for (size_t i = first; parserTests[i]; i++)
{
fprintf(stderr, "\n### Test %d ###\n\n", (int)i), fflush(stderr);
let parserTest = parserTests[i];
let expr = ParseConfigString(standardFunctions + computationNodes + commonMacros + parserTest);
//expr->Dump();
Do(expr);
if (oneOnly)
break;
}
}
catch (const ConfigError & err)
{
err.PrintError();
}
}
}}} // namespaces

400
BrainScript/Notes.txt Normal file
Просмотреть файл

@ -0,0 +1,400 @@
CNTK configuration language redesign (ongoing work)
====================================
F. Seide, August 2015
These are the original notes from before coding began. Basic ideas are correct, but may be a bit outdated.
- config specifies all configurable runtime objects and their initialization parameters
- basic concepts: dictionaries and runtime-object definitions
- basic syntactic elements:
- runtime object definitions // new classname initargsdictionary
- macro definition // M(x,y,z) = expression // expression uses x, y, and z
- expressions
- dictionaries // [ a=expr1 ; c=expr2 ]
- math ops and parentheses as usual // W*v+a, n==0
- conditional expression // if c then a else b
- array // a:b:c ; array [1..N] (i => f(i))
- syntax supports usual math and boolean expressions
- functions are runtime objects defined through macros, e.g. Replace(s,with,withwhat) = String [ from=s ; replacing=what ; with=withwhat ]
- config is parsed eagerly but evaluated lazily
- CNTK command line "configFile=conf.bs a=b c=d" expands to "new CNTK {content of conf.bs} + [ a=b ; c=d ]"
current issues
--------------
- syntax does not distinguish between dictionary members, intermediate variables, and actual parameter names
- dictionary editing needs to allow a.b.c syntax; and subtracting is not pretty as it needs dummy values -> maybe use a delete symbol? a=delete?
- missing: optional parameters to macros; and how this whole thing would work with MEL
grammar
-------
// --- top level defines a runtime object of class 'CNTK'
// example: new CNTK [ actions=train ; train=TrainAction [ ... ] ] // where "new CNTK [" is prepended by the command-line parser
$ = $dictitems // this is a dictionary without enclosing [ ... ] that defines instantiation args of CNTK class
// --- defining a runtime object and its parameters
// example: new ComputeNode [ class="Plus" ; arg1=A ; arg2=B ]
$newinstance = 'new' $classname $expr
where $expr must be a dictionary expression
$classname = $identifier
where $identifier is one of the known pre-defined C++ class names
// --- dictionaries are groups of key-value pairs.
// Dictionaries are expressions.
// Multiple dictionaries can be edited (dict1 + dict2) where dict2 members override dict1 ones of the same name.
// examples: [ arg1=A ; arg2=B ]
// dict1 + (if (dpt && layer < totallayers) then [ numiter = 5 ] else []) // overrides 'numiter' in 'dict1' if condition is fulfilled
$dictdef = '[' $dictitems ']'
$dictitems = $itemdef*
$itemdef = $paramdef // var=val
| $macrodef // macro(args)=expression
$paramdef = $identifier '=' $expr // e.g. numiter = 13
$macrodef = $identifier '(' $arg (',' $arg) ')' = $expr // e.g. sqr(x) = x*x
// --- expressions
// Expressions are what you'd expect. Infix operators those of C, with addition of '.*' '**' ':' '..'
// ML-style "let ... in" (expression-local variables) are possible but not super-pretty: [ a=13; b=42; res=a*b ].res
// There are infix ops for strings (concatenation) and dictionaries (editing).
$expr = $operand
| $expr $infixop $operand
| $expr '.' $memberref // dict.member TODO: fix this; memberrefs exist without '.'
where $expr is a dictionary
| $expr '(' $expr (',' $expr)* ')' // a(13) also: dict.a(13); note: partial application possible, i.e. macros may be passed as args and curried
where $expr is a macro
| $expr '[' $expr ']' // h_fwd[t]
where first $expr must be a array and second $expr a number (that must be an integer value)
$infixop = // highest precedence level
'*' // numbers; also magic short-hand for "Times" and "Scale" ComputeNodes
| '/' // numbers; Scale ComputeNode
| '.*' // ComputeNodes: component-wise product
| '**' // numbers (exponentiation, FORTRAN style!)
| '%' // numbers: remainder
// next lower precedence level
| '+' // numbers; ComputeNodes; strings; dictionary editing
| '-' // numbers; ComputeNodes; dictionary editing
// next lower precedence level
| '==' '!=' '<' '>' '<=' '>=' // applies to config items only; objects other than boxed primitive values are compared by object identity not content
// next lower precedence level
| '&&' // booleans
// next lower precedence level
| '||' | '^' // booleans
// next lower precedence level
| ':' // concatenate items and/or arrays --TODO: can arrays have nested arrays? Syntax?
$operand = $literal // "Hello World"
| $memberref // a
| $dictdef // [ a="Hello World" ]
| $newinstance // new ComputeNode [ ... ]
| ('-' | '+' | '!') $operand // -X+Y
| '(' $expr ')' // (a==b) || (c==d)
| $arrayconstructor // array [1..N] (i => i*i)
$literal = $number // built-in literal types are numeric, string, and boolean
| $string
| $boolconst
$number = // floating point number; no separate 'int' type, 'int' args are checked at runtime to be non-fractional
$string = // characters enclosed in "" or ''; no escape characters inside, use combinations of "", '', and + instead (TODO: do we need string interpolation?).
// Strings may span multiple lines (containing newlines)
$boolconst = 'true' | 'false'
$memberref = $identifier // will search parent scopes
$arrayconstructor = 'array' '[' $expr '..' $expr ']' '(' $identifier '=>' $expr ')' // array [1..N] (i => i*i)
where ^start ^end (int) ^index variable ^function of index variable
// --- predefined functions
// *All* functions are defined as macros that instantiate a runtime object. (The same is true for operators above, too, actually.)
// functions that really are macros that instantiate ComputeNodes:
// - Times(,), Plus(,), Sigmoid(), etc.
// numeric functions:
// - Floor() (for int division), Ceil(), Round() (for rounding), Abs(), Sign(), ...
// string functions:
// - Replace(s,what,withwhat), Str(number) (number to string), Chr(number) (convert Unicode codepoint to string), Format(fmt,val) (sprintf-like formatting with one arg)
// other:
// - Fail("error description") --will throw exception when executed; use this like assertion
dictionaries
------------
- dictionaries are key-value pairs; they are records or compound data structures for use inside the config file itself
- dictionaries are immutable and exist inside the parser but are not serialized to disk with a model --TODO: it might be needed to do that for MEL
- the argument to a runtime-object instantiation is also a dictionary
- the config file can access that dictionary's members directly from the runtime-object expression, for convenience
- intermediate variables that are only used to construct dictionary entries also become dictionary entries (no syntactic distinction) --TODO: should we distinguish them?
- macros are also dictionary members
- dictionary values are read out using dict.field syntax, where 'dict' is any expression that evaluates to a dictionary
- object instantiations will also traverse outer scopes to find values (e.g. precision, which is shared by many)
- runtime objects themselves are inputs to other runtime objects, but they cannot have data members that output values
- instead, output arguments use a proxy class ComputeNodeRef that can be used as a ComputeNode for input, and gets filled in at runtime
- dictionaries can be "edited" by "adding" (+) a second dictionary to it; items from the second will overwrite the same items in the first.
Subtracting a dictionary will remove all items in the second dict from the first.
This is used to allow for overriding variables on the command line. --TODO: not fully fleshed out how to access nested inner variables inside a dict
arrays
------
- another core data type is the array. Like dictionaries, arrays are immutable and exist inside the parser only.
- arrays are created at once in two ways
- 'array' expression:
array [1..N] (i => f(i)) // fake lambda syntax could be made real lambda; also could extend to multi-dim arrays
- ':' operator concatenates arrays and/or elements. Arrays are flattened.
1:2:3
- elements are read-accessed with index operator
X[i]
- example syntax of how one could define useful operators for arrays
- Append(seq,item) = seq : item
- Repeat(item,N) = array [1..N] (i => item)
- arrays with repetition can be created like this:
0.8 : array [1..3] (i => 0.2) : 0.05
or
0.8 : Repeat(0.2,3) : 0.05
- the array[] () argument looks like a C# lambda, but for now is hard-coded syntax (but with potential to be a true lambda in the future)
towards MEL
-----------
Model editing is now done in a functional manner, like this:
TIMIT_AddLayer = new EditAction [
currModelPath = "ExpDir\TrainWithPreTrain\dptmodel1\cntkSpeech.dnn"
newModelPath = "ExpDir\TrainWithPreTrain\dptmodel2\cntkSpeech.dnn.0"
model = LoadModel(currModelPath);
newModel = EditModel(model, [
// new items here
outZ = SBFF(model.outZ.INPUT, LABELDIM, outZ.INPUT.OUTDIM)
])
do = ( Dump(newModel, newModelPath + ".dump.txt")
: SaveModel(newModel, newModelPath) )
]
sample
------
// This sample is a modification of the original TIMIT_TrainSimpleNetwork.config and TIMIT_TrainNDLNetwork.config.
// The changes compared to the origina syntax are called out in comments.
stderr = ExpDir + "\TrainSimpleNetwork\log\log" // before: $ExpDir$\TrainSimpleNetwork\log\log
actions = TIMIT_TrainSimple // before: command = ... ('command' is singular, but this can be a sequence of actions)
// these values are used by several runtime-object instantiations below
precision = 'float' // before: precision = float
deviceId = DeviceNumber // before: $DeviceNumber$
#######################################
# TRAINING CONFIG (Simple, Fixed LR) #
#######################################
Repeat(val,count) = array [1..count] (i => val) // new: array helper to repeat a value (result is a array) (this would be defined in a library eventually)
TIMIT_TrainSimple = new TrainAction [ // new: added TrainAction; this is a class name of the underlying runtime object
// new: TrainAction takes three main parameters: 'source' -> 'model' -> 'optimizer' (-> indicating logical dependency)
//action = train // removed (covered by class name)
traceLevel = 1
// new: Model object; some parameters were moved into this
model = new Model [ // this is an input to TrainAction
modelPath = ExpDir + "\TrainSimpleNetwork\model\cntkSpeech.dnn" // before: $ExpDir$\TrainSimpleNetwork\model\cntkSpeech.dnn
// EXAMPLE 1: SimpleNetworkBuilder --TODO: do we even need a C++ class, or can we use a macro instead? Would make life easier re connecting inputs
network = new SimpleNetworkBuilder [ // before: SimpleNetworkBuilder = [
layerSizes = 792 : Repeat(512,3) : 183 // before: 792:512*3:183
layerTypes = 'Sigmoid' // before: no quotes
initValueScale = 1.0
applyMeanVarNorm = true
uniformInit = true
needPrior = true
// the following two belong into SGD, so they were removed here
//trainingCriterion = CrossEntropyWithSoftmax
//evalCriterion = ErrorPrediction
// new: connect to input stream from source; and expose the output layer
input = source.features.data // these are also ComputeNodeRefs, exposed by the source
output = ComputeNodeRef [ dim = source.labels.dim ] // SimpleNetworkBuilder will put top layer affine transform output (input to softmax) here
// criteria are configurable here; these are ComputeNodes created here
trainingCriterion = CrossEntropyWithSoftmax (source.labels.data, output)
evalCriterion = ErrorPrediction (source.labels.data, output)
// new: (and half-baked) define Input nodes
myFeatures=Input(featDim) // reader stream will reference this
myLabels=Input(labelDim)
]
// EXAMPLE 2: network from NDL (an actual config would contain one of these two examples)
network = new NDL [ // before: run=ndlCreateNetwork ; ndlCreateNetwork=[
featDim = myFeatures.dim // before: 792 hard-coded; note: myFeatures and myLabels are defined below
labelDim = myLabels.dim // before: 183 hard-coded
hiddenDim = 512
// input nodes
myFeatures=Input(featDim) // before: optional arg tag=feature
myLabels=Input(labelDim) // before: optional arg tag=label
// old
//# define network
//featNorm = MeanVarNorm(myFeatures)
//L1 = SBFF(featNorm,hiddenDim,featDim)
//L2 = SBFF(L1,hiddenDim,hiddenDim)
//L3 = SBFF(L2,hiddenDim,hiddenDim)
//CE = SMBFF(L3,labelDim,hiddenDim,myLabels,tag=Criteria)
//Err = ErrorPrediction(myLabels,CE.BFF.FF.P,tag=Eval)
//logPrior = LogPrior(myLabels)
//ScaledLogLikelihood=Minus(CE.BFF.FF.P,logPrior,tag=Output)
// new:
// Let's have the macros declared here for illustration (in the end, these would live in a library)
FF(X1, W1, B1) = W1 * X1 + B1 // before: T=Times(W1,X1) ; P=Plus(T, B1)
BFF(in, rows, cols) = [ // before: BFF(in, rows, cols) { ... }
B = Parameter(rows, init = fixedvalue, value = 0)
W = Parameter(rows, cols)
z = FF(in, w, b) // before: FF = ...; illegal now, cannot use same name again
]
SBFF(in, rowCount, colCount) = [ // before: SBFF(in,rowCount,colCount) { ... }
z = BFF(in, rowCount, colCount).z // before: BFF = BFF(in, rowCount, colCount)
Eh = Sigmoid(z)
]
// Macros are expressions. FF returns a ComputeNode; while BFF and SBFF return a dictionary that contains multiple named ComputeNode.
// new: define network in a loop. This allows parameterizing over the network depth.
numLayers = 7
layers = array [0..numLayers] ( layer =>
if layer == 0 then featNorm
else if layer == 1 then SBFF(layers[layer-1].Eh, hiddenDim, featDim)
else if layer < numLayers then SBFF(layers[layer-1].Eh, hiddenDim, hiddenDim)
else BFF(layers[layer-1].Eh, labelDim, hiddenDim)
)
outZ = layers[numLayers].z // new: to access the output value, the variable name (dictionary member) cannot be omitted
// alternative to the above: define network with recursion
HiddenStack(layer) = if layer > 1 then SBFF(HiddenStack(layer-1).Eh, hiddenDim, hiddenDim) else SBFF(featNorm, hiddenDim, featDim)
outZ = BFF(HiddenStack(numlayers).Eh, labelDim, hiddenDim)
// define criterion nodes
CE = CrossEntropyWithSoftmax(myLabels, outZ)
Err = ErrorPrediction(myLabels, outZ)
// define output node for decoding
logPrior = LogPrior(myLabels)
ScaledLogLikelihood = outZ - logPrior // before: Minus(CE.BFF.FF.P,logPrior,tag=Output)
]
]
// the SGD optimizer
optimizer = new SGDOptimizer [ // before: SGD = [
epochSize = 0
minibatchSize = 256 : 1024
learningRatesPerMB = 0.8 : Repeat(3.2,14) : 0.08 // (syntax change for repetition)
momentumPerMB = 0.9
dropoutRate = 0.0
maxEpochs = 25
// new: link to the criterion node
trainingCriterion = model.network.CE // (note: I would like to rename this to 'objective')
]
// The RandomizingSource performs randomization and mini-batching, while driving low-level random-access readers.
source = new RandomizingSource [ // before: reader = [
//readerType = HTKMLFReader // removed since covered by class name
// new: define what utterances to get from what stream sources
dataSetFile = ScpDir + "\TIMIT.train.scp.fbank.fullpath" // (new) defines set of utterances to train on; accepts HTK archives
streams = ( [ // This passes the 'features' and 'labels' runtime objects to the source, and also connects them to the model's Input nodes.
reader = features // random-access reader
input = model.network.myFeatures // Input node that this feeds into
]
: [
reader = labels
input = model.network.myLabels
] ) // note: ':' is array syntax. Parentheses are only for readability
readMethod = 'blockRandomize' // before: no quotes
miniBatchMode = 'Partial' // before: no quotes
randomize = 'Auto' // before: no quotes
verbosity = 1
// change: The following two are not accessed directly by the source, but indirectly through the 'streams' argument.
// They could also be defined outside of this dictionary. They are from the NDL, though.
// The 'RandomizingSource' does not know about features and labels specifically.
features = new HTKFeatReader [ // before: features = [
//dim = 792 // (moved to 'data' node)
scpFile = dataSetFile // HTK reader can share source's archive file that defines dataSet
data = new ComputeNodeRef [ dim = 792 ] // an input node the model can connect to; dimension is verified when files are opened
]
labels = new HTKMLFReader [ // before: labels = [
mlfFile = MlfDir + "\TIMIT.train.align_cistate.mlf.cntk" // before: $MlfDir$\TIMIT.train.align_cistate.mlf.cntk
//labelDim = 183 // (moved to 'data' node)
labelMappingFile = MlfDir + "\TIMIT.statelist" // before: $MlfDir$\TIMIT.statelist
data = new ComputeNodeRef [ dim = 183 ] // an input node the model can connect to; dimension is verified when reading statelist file
]
]
]
Example 2: truncated bidirectional RNN
--------------------------------------
network = new NDL [
// network parameters
hiddenDim = 512
numHiddenLayers = 6 // 6 hidden layers
T = 41 // total context window
// data sources
myFeatures = source.features.data
myLabels = source.labels.data
// derived dimensions
augmentedFeatDim = myFeatures.dim // feature arrays are context window frames stacked into a single long array
labelDim = myLabels.dim
centerT = Floor(T/2) // center frame to predict
featDim = Floor(augmentedFeatDim / T)
// split the augmented input vector into individual frame vectors
subframes = array [0..T-1] (t => RowSlice(t * featDim, featDim, myFeatures))
// hidden layers
// Hidden state arrays for all frames are stored in a array object.
layers = array [1..numHiddenLayers] (layer => [ // each layer stores a dictionary that stores its output hidden fwd and bwd state vectors
// model parameters
W_fwd = Parameter(hiddenDim, featDim) // Parameter(outdim, indim) --in_fwd.rows is an initialization parameter read from the dict
W_bwd = if layer > 1 then Parameter(hiddenDim, hiddenDim) else Fail("no W_bwd") // W denotes input-to-hidden connections
H_fwd = Parameter(hiddenDim, hiddenDim) // H denotes hidden-to-hidden lateral connections
H_bwd = Parameter(hiddenDim, hiddenDim)
b = Parameter(hiddenDim, 1) // bias
// shared part of activations (input connections and bias)
z_shared = array [0..T-1] (t => if layers > 1 then W_fwd * layers[layer-1].h_fwd[t] + W_bwd * layers[layer-1].h_bwd[t] + b // intermediate layer gets fed fwd and bwd hidden state
else W_fwd * subframes + b) // input layer reads frames directly
// recurrent part and non-linearity
neededT = if layer < numHiddenLayers then T else centerT+1 // last hidden layer does not require all frames
step(H,h,dt,t) = Sigmoid(if (t+dt > 0 && t+dt < T) then z_shared[t] + H * h[t+dt]
else z_shared[t])
h_fwd = array [0..neededT-1] (t => step(H_fwd, h_fwd, -1, t))
h_bwd = array [T-neededT..T-1] (t => step(H_bwd, h_bwd, 1, t))
])
// output layer --linear only at this point; Softmax is applied later
outZ = [
// model parameters
W_fwd = Parameter(labelDim, hiddenDim)
W_bwd = Parameter(labelDim, hiddenDim)
b = Parameter(labelDim, 1)
// output
topHiddenLayer = layers[numHiddenLayers]
z = W_fwd * topHiddenLayer.h_fwd[centerT] + W_bwd * topHiddenLayer.h_bwd[centerT] + b
].z // we only want this one & don't care about the rest of this dictionary
// define criterion nodes
CE = CrossEntropyWithSoftmax(myLabels, outZ)
Err = ErrorPrediction(myLabels, outZ)
// define output node for decoding
logPrior = LogPrior(myLabels)
ScaledLogLikelihood = outZ - logPrior // before: Minus(CE.BFF.FF.P,logPrior,tag=Output)
]

80
BrainScript/test.config Normal file
Просмотреть файл

@ -0,0 +1,80 @@
#
# test this with this command line:
# configFile=$(SolutionDir)BrainScript/test.config RunDir=$(SolutionDir)\Tests\Speech\RunDir DataDir=$(SolutionDir)\Tests\Speech\Data DeviceId=Auto
precision=float
command=speechTrain
deviceId=$DeviceId$
parallelTrain=false
speechTrain=[
action=train
modelPath=$RunDir$/models/cntkSpeech.dnn
deviceId=$DeviceId$
traceLevel=1
# inside here is the new stuff
ExperimentalNetworkBuilder=[
//deviceId = -21 ; precision = 'floax' // for now
layerSizes=363:512:512:132
trainingCriterion=CE
evalCriterion=Err
//layerTypes=Sigmoid
//initValueScale=1.0
//applyMeanVarNorm=true
//uniformInit=true
//needPrior=true
numHiddenLayers = 3
myFeatures = Input(layerSizes[0]) ; myLabels = Input(layerSizes[Length(layerSizes)-1])
featNorm = MeanVarNorm(myFeatures)
layers = array[1..numHiddenLayers] (layer => if layer > 1 then SBFF(layers[layer-1].Eh, layerSizes[layer], layerSizes[layer-1]) else SBFF(featNorm, layerSizes[layer], layerSizes[layer-1]))
outLayer = BFF(layers[numHiddenLayers].Eh, labelDim, hiddenDim)
outZ = outLayer.z
CE = CrossEntropyWithSoftmax(myLabels, outZ)
Err = ErrorPrediction(myLabels, outZ)
logPrior = LogPrior(myLabels)
ScaledLogLikelihood = outZ - logPrior
]
SGD=[
epochSize=20480
minibatchSize=64:256:1024:
learningRatesPerMB=1.0:0.5:0.1
numMBsToShowResult=10
momentumPerMB=0.9:0.656119
dropoutRate=0.0
maxEpochs=3
keepCheckPointFiles=true
AutoAdjust=[
reduceLearnRateIfImproveLessThan=0
loadBestModel=true
increaseLearnRateIfImproveMoreThan=1000000000
learnRateDecreaseFactor=0.5
learnRateIncreaseFactor=1.382
autoAdjustLR=AdjustAfterEpoch
]
clippingThresholdPerSample=1#INF
]
reader=[
readerType=HTKMLFReader
readMethod=blockRandomize
miniBatchMode=Partial
randomize=Auto
verbosity=0
features=[
dim=363
type=Real
scpFile=glob_0000.scp
]
labels=[
mlfFile=$DataDir$/glob_0000.mlf
labelMappingFile=$DataDir$/state.list
labelDim=132
labelType=Category
]
]
]

Просмотреть файл

@ -3,13 +3,14 @@ Microsoft Visual Studio Solution File, Format Version 12.00
# Visual Studio 2013
VisualStudioVersion = 12.0.21005.1
MinimumVisualStudioVersion = 10.0.40219.1
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CNTKMath", "Math\Math\Math.vcxproj", "{60BDB847-D0C4-4FD3-A947-0C15C08BCDB5}"
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CNTKMathDll", "Math\Math\Math.vcxproj", "{60BDB847-D0C4-4FD3-A947-0C15C08BCDB5}"
ProjectSection(ProjectDependencies) = postProject
{B3DD765E-694E-4494-BAD7-37BBF2942517} = {B3DD765E-694E-4494-BAD7-37BBF2942517}
EndProjectSection
EndProject
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CNTK", "MachineLearning\CNTK\CNTK.vcxproj", "{E6F26F9A-FF64-4F0A-B749-CD309EE357EE}"
ProjectSection(ProjectDependencies) = postProject
{928ABD1B-4D3B-4017-AEF1-0FA1B4467513} = {928ABD1B-4D3B-4017-AEF1-0FA1B4467513}
{33D2FD22-DEF2-4507-A58A-368F641AEBE5} = {33D2FD22-DEF2-4507-A58A-368F641AEBE5}
{D667AF32-028A-4A5D-BE19-F46776F0F6B2} = {D667AF32-028A-4A5D-BE19-F46776F0F6B2}
{9A2F2441-5972-4EA8-9215-4119FCE0FB68} = {9A2F2441-5972-4EA8-9215-4119FCE0FB68}
@ -17,6 +18,7 @@ Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CNTK", "MachineLearning\CNT
{014DA766-B37B-4581-BC26-963EA5507931} = {014DA766-B37B-4581-BC26-963EA5507931}
{62836DC1-DF77-4B98-BF2D-45C943B7DDC6} = {62836DC1-DF77-4B98-BF2D-45C943B7DDC6}
{1D5787D4-52E4-45DB-951B-82F220EE0C6A} = {1D5787D4-52E4-45DB-951B-82F220EE0C6A}
{DE3C54E5-D7D0-47AF-A783-DFDCE59E7937} = {DE3C54E5-D7D0-47AF-A783-DFDCE59E7937}
{E6646FFE-3588-4276-8A15-8D65C22711C1} = {E6646FFE-3588-4276-8A15-8D65C22711C1}
EndProjectSection
EndProject
@ -50,8 +52,9 @@ Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "LUSequenceReader", "DataRea
{60BDB847-D0C4-4FD3-A947-0C15C08BCDB5} = {60BDB847-D0C4-4FD3-A947-0C15C08BCDB5}
EndProjectSection
EndProject
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CNTKEval", "MachineLearning\CNTKEval\CNTKEval.vcxproj", "{482999D1-B7E2-466E-9F8D-2119F93EAFD9}"
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CNTKEvalDll", "MachineLearning\CNTKEval\CNTKEval.vcxproj", "{482999D1-B7E2-466E-9F8D-2119F93EAFD9}"
ProjectSection(ProjectDependencies) = postProject
{928ABD1B-4D3B-4017-AEF1-0FA1B4467513} = {928ABD1B-4D3B-4017-AEF1-0FA1B4467513}
{60BDB847-D0C4-4FD3-A947-0C15C08BCDB5} = {60BDB847-D0C4-4FD3-A947-0C15C08BCDB5}
EndProjectSection
EndProject
@ -84,9 +87,8 @@ Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "LibSVMBinaryReader", "DataR
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Linux build files", "Linux build files", "{3ED0465D-23E7-4855-9694-F788717B6533}"
ProjectSection(SolutionItems) = preProject
configure = configure
Makefile = Makefile
Makefile_kaldi.cpu = Makefile_kaldi.cpu
Makefile_kaldi.gpu = Makefile_kaldi.gpu
README = README
EndProjectSection
EndProject
@ -197,12 +199,59 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Data", "Data", "{5F733BBA-F
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "LSTM", "LSTM", "{19EE975B-232D-49F0-94C7-6F1C6424FB53}"
ProjectSection(SolutionItems) = preProject
Tests\Speech\LSTM\baseline.cpu.txt = Tests\Speech\LSTM\baseline.cpu.txt
Tests\Speech\LSTM\baseline.gpu.txt = Tests\Speech\LSTM\baseline.gpu.txt
Tests\Speech\LSTM\baseline.windows.cpu.txt = Tests\Speech\LSTM\baseline.windows.cpu.txt
Tests\Speech\LSTM\baseline.windows.gpu.txt = Tests\Speech\LSTM\baseline.windows.gpu.txt
Tests\Speech\LSTM\cntk.config = Tests\Speech\LSTM\cntk.config
Tests\Speech\LSTM\lstmp-3layer_WithSelfStab.ndl = Tests\Speech\LSTM\lstmp-3layer_WithSelfStab.ndl
Tests\Speech\LSTM\run-test = Tests\Speech\LSTM\run-test
Tests\Speech\LSTM\testcases.yml = Tests\Speech\LSTM\testcases.yml
EndProjectSection
EndProject
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "ParseConfig", "MachineLearning\ParseConfig\ParseConfig.vcxproj", "{7C4E77C9-6B17-4B02-82C1-DB62EEE2635B}"
EndProject
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CNTKComputationNetworkLib", "MachineLearning\CNTKComputationNetworkLib\CNTKComputationNetworkLib.vcxproj", "{928ABD1B-4D3B-4017-AEF1-0FA1B4467513}"
ProjectSection(ProjectDependencies) = postProject
{60BDB847-D0C4-4FD3-A947-0C15C08BCDB5} = {60BDB847-D0C4-4FD3-A947-0C15C08BCDB5}
EndProjectSection
EndProject
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CNTKSGDLib", "MachineLearning\CNTKSGDLib\CNTKSGDLib.vcxproj", "{DE3C54E5-D7D0-47AF-A783-DFDCE59E7937}"
ProjectSection(ProjectDependencies) = postProject
{928ABD1B-4D3B-4017-AEF1-0FA1B4467513} = {928ABD1B-4D3B-4017-AEF1-0FA1B4467513}
EndProjectSection
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "ParallelTraining", "ParallelTraining", "{5E666C53-2D82-49C9-9127-3FDDC321C741}"
ProjectSection(SolutionItems) = preProject
Tests\ParallelTraining\SimpleMultiGPU.config = Tests\ParallelTraining\SimpleMultiGPU.config
EndProjectSection
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Data", "Data", "{6D1353D6-F196-466F-B886-F16D48759B20}"
ProjectSection(SolutionItems) = preProject
Tests\ParallelTraining\Data\SimpleDataTrain.txt = Tests\ParallelTraining\Data\SimpleDataTrain.txt
Tests\ParallelTraining\Data\SimpleMapping.txt = Tests\ParallelTraining\Data\SimpleMapping.txt
EndProjectSection
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "NoQuantization", "NoQuantization", "{B6725C9F-A6D2-4269-9B74-7888A90F7884}"
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "SinglePrecision", "SinglePrecision", "{B27DD434-EECD-4EE0-A03B-1150EB87258E}"
ProjectSection(SolutionItems) = preProject
Tests\ParallelTraining\NoQuantization\SinglePrecision\baseline.cpu.txt = Tests\ParallelTraining\NoQuantization\SinglePrecision\baseline.cpu.txt
Tests\ParallelTraining\NoQuantization\SinglePrecision\baseline.gpu.txt = Tests\ParallelTraining\NoQuantization\SinglePrecision\baseline.gpu.txt
Tests\ParallelTraining\NoQuantization\SinglePrecision\baseline.windows.cpu.txt = Tests\ParallelTraining\NoQuantization\SinglePrecision\baseline.windows.cpu.txt
Tests\ParallelTraining\NoQuantization\SinglePrecision\run-test = Tests\ParallelTraining\NoQuantization\SinglePrecision\run-test
Tests\ParallelTraining\NoQuantization\SinglePrecision\testcases.yml = Tests\ParallelTraining\NoQuantization\SinglePrecision\testcases.yml
EndProjectSection
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "DoublePrecision", "DoublePrecision", "{A4884465-CFBB-4A64-A9DE-690E1A63EF7E}"
ProjectSection(SolutionItems) = preProject
Tests\ParallelTraining\NoQuantization\DoublePrecision\baseline.cpu.txt = Tests\ParallelTraining\NoQuantization\DoublePrecision\baseline.cpu.txt
Tests\ParallelTraining\NoQuantization\DoublePrecision\baseline.gpu.txt = Tests\ParallelTraining\NoQuantization\DoublePrecision\baseline.gpu.txt
Tests\ParallelTraining\NoQuantization\DoublePrecision\baseline.windows.cpu.txt = Tests\ParallelTraining\NoQuantization\DoublePrecision\baseline.windows.cpu.txt
Tests\ParallelTraining\NoQuantization\DoublePrecision\run-test = Tests\ParallelTraining\NoQuantization\DoublePrecision\run-test
Tests\ParallelTraining\NoQuantization\DoublePrecision\testcases.yml = Tests\ParallelTraining\NoQuantization\DoublePrecision\testcases.yml
EndProjectSection
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|x64 = Debug|x64
@ -268,6 +317,18 @@ Global
{CE429AA2-3778-4619-8FD1-49BA3B81197B}.Debug|x64.Build.0 = Debug|x64
{CE429AA2-3778-4619-8FD1-49BA3B81197B}.Release|x64.ActiveCfg = Release|x64
{CE429AA2-3778-4619-8FD1-49BA3B81197B}.Release|x64.Build.0 = Release|x64
{7C4E77C9-6B17-4B02-82C1-DB62EEE2635B}.Debug|x64.ActiveCfg = Debug|x64
{7C4E77C9-6B17-4B02-82C1-DB62EEE2635B}.Debug|x64.Build.0 = Debug|x64
{7C4E77C9-6B17-4B02-82C1-DB62EEE2635B}.Release|x64.ActiveCfg = Release|x64
{7C4E77C9-6B17-4B02-82C1-DB62EEE2635B}.Release|x64.Build.0 = Release|x64
{928ABD1B-4D3B-4017-AEF1-0FA1B4467513}.Debug|x64.ActiveCfg = Debug|x64
{928ABD1B-4D3B-4017-AEF1-0FA1B4467513}.Debug|x64.Build.0 = Debug|x64
{928ABD1B-4D3B-4017-AEF1-0FA1B4467513}.Release|x64.ActiveCfg = Release|x64
{928ABD1B-4D3B-4017-AEF1-0FA1B4467513}.Release|x64.Build.0 = Release|x64
{DE3C54E5-D7D0-47AF-A783-DFDCE59E7937}.Debug|x64.ActiveCfg = Debug|x64
{DE3C54E5-D7D0-47AF-A783-DFDCE59E7937}.Debug|x64.Build.0 = Debug|x64
{DE3C54E5-D7D0-47AF-A783-DFDCE59E7937}.Release|x64.ActiveCfg = Release|x64
{DE3C54E5-D7D0-47AF-A783-DFDCE59E7937}.Release|x64.Build.0 = Release|x64
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
@ -277,11 +338,15 @@ Global
{482999D1-B7E2-466E-9F8D-2119F93EAFD9} = {DD043083-71A4-409A-AA91-F9C548DCF7EC}
{60BDB847-D0C4-4FD3-A947-0C15C08BCDB5} = {DD043083-71A4-409A-AA91-F9C548DCF7EC}
{B3DD765E-694E-4494-BAD7-37BBF2942517} = {DD043083-71A4-409A-AA91-F9C548DCF7EC}
{928ABD1B-4D3B-4017-AEF1-0FA1B4467513} = {DD043083-71A4-409A-AA91-F9C548DCF7EC}
{DE3C54E5-D7D0-47AF-A783-DFDCE59E7937} = {DD043083-71A4-409A-AA91-F9C548DCF7EC}
{6CEE834A-8104-46A8-8902-64C81BD7928F} = {D45DF403-6781-444E-B654-A96868C5BE68}
{668BEED5-AC07-4F35-B3AE-EE65A7F9C976} = {D45DF403-6781-444E-B654-A96868C5BE68}
{0F30EBCF-09F3-4EED-BF54-4214BCE53FEC} = {D45DF403-6781-444E-B654-A96868C5BE68}
{DBB3C106-B0B4-4059-8477-C89528CEC1B0} = {D45DF403-6781-444E-B654-A96868C5BE68}
{C47CDAA5-6D6C-429E-BC89-7CA0F868FDC8} = {D45DF403-6781-444E-B654-A96868C5BE68}
{5E666C53-2D82-49C9-9127-3FDDC321C741} = {D45DF403-6781-444E-B654-A96868C5BE68}
{7C4E77C9-6B17-4B02-82C1-DB62EEE2635B} = {D45DF403-6781-444E-B654-A96868C5BE68}
{E6646FFE-3588-4276-8A15-8D65C22711C1} = {33EBFE78-A1A8-4961-8938-92A271941F94}
{1D5787D4-52E4-45DB-951B-82F220EE0C6A} = {33EBFE78-A1A8-4961-8938-92A271941F94}
{62836DC1-DF77-4B98-BF2D-45C943B7DDC6} = {33EBFE78-A1A8-4961-8938-92A271941F94}
@ -301,5 +366,9 @@ Global
{4BBF2950-3DBD-469A-AD57-6CACBEBAF541} = {C47CDAA5-6D6C-429E-BC89-7CA0F868FDC8}
{5F733BBA-FE83-4668-8F83-8B0E78A36619} = {C47CDAA5-6D6C-429E-BC89-7CA0F868FDC8}
{19EE975B-232D-49F0-94C7-6F1C6424FB53} = {C47CDAA5-6D6C-429E-BC89-7CA0F868FDC8}
{6D1353D6-F196-466F-B886-F16D48759B20} = {5E666C53-2D82-49C9-9127-3FDDC321C741}
{B6725C9F-A6D2-4269-9B74-7888A90F7884} = {5E666C53-2D82-49C9-9127-3FDDC321C741}
{B27DD434-EECD-4EE0-A03B-1150EB87258E} = {B6725C9F-A6D2-4269-9B74-7888A90F7884}
{A4884465-CFBB-4A64-A9DE-690E1A63EF7E} = {B6725C9F-A6D2-4269-9B74-7888A90F7884}
EndGlobalSection
EndGlobal

Просмотреть файл

@ -23,6 +23,8 @@
#include <nvml.h> // note: expected at "c:\Program Files\NVIDIA Corporation\GDK\gdk_win7_amd64_release\nvml\include" (Windows) and /the path you installed deployment kit/usr/include/nvidia/gdk (Linux)
#pragma comment (lib, "nvml.lib") // note: expected at "c:\Program Files\NVIDIA Corporation\GDK\gdk_win7_amd64_release\nvml\lib" (Windows) and /the path you installed deployment kit/usr/include/nvidia/gdk (Linux)
#include <vector>
#else
int bestGPUDummy = 42; // put something into this CPP, as to avoid a linker warning
#endif
#include "CommonMatrix.h" // for CPUDEVICE and AUTOPLACEMATRIX
@ -43,9 +45,6 @@
#include <memory>
#include "CrossProcessMutex.h"
#include "../../MachineLearning/CNTK/MPIWrapper.h"
extern Microsoft::MSR::CNTK::MPIWrapper *g_mpi;
// ---------------------------------------------------------------------------
// BestGpu class
@ -123,6 +122,9 @@ private:
// 0:2:3- an array of ids to use, (PTask will only use the specified IDs)
// *3 - a count of GPUs to use (PTask)
// All - Use all the GPUs (PTask)
#ifdef MATH_EXPORTS
__declspec(dllexport)
#endif
DEVICEID_TYPE DeviceFromConfig(const ConfigParameters& config)
{
static BestGpu* g_bestGpu = NULL;

Просмотреть файл

@ -242,7 +242,7 @@ bool DataReader<ElemType>::GetProposalObs(std::map<std::wstring, Matrix<ElemType
}
template<class ElemType>
void DataReader<ElemType>::SetSentenceSegBatch(Matrix<ElemType> &sentenceEnd, vector<MinibatchPackingFlag>& minibatchPackingFlag)
void DataReader<ElemType>::SetSentenceSegBatch(Matrix<float> &sentenceEnd, vector<MinibatchPackingFlag>& minibatchPackingFlag)
{
for (size_t i = 0; i < m_ioNames.size(); i++)
m_dataReader[m_ioNames[i]]->SetSentenceSegBatch(sentenceEnd, minibatchPackingFlag);
@ -259,7 +259,7 @@ template<class ElemType>
bool DataReader<ElemType>::GetMinibatchCopy(
std::vector<std::vector<std::pair<wstring, size_t>>>& uttInfo,
std::map<std::wstring, Matrix<ElemType>*>& matrices,
Matrix<ElemType>& sentenceBegin,
Matrix<float>& sentenceBegin,
std::vector<MinibatchPackingFlag>& minibatchPackingFlag)
{
bool ans = false;
@ -272,7 +272,7 @@ template<class ElemType>
bool DataReader<ElemType>::SetNetOutput(
const std::vector<std::vector<std::pair<wstring, size_t>>>& uttInfo,
const Matrix<ElemType>& outputs,
const Matrix<ElemType>& sentenceBegin,
const Matrix<float>& sentenceBegin,
const std::vector<MinibatchPackingFlag>& minibatchPackingFlag)
{
bool ans = false;

Просмотреть файл

@ -169,6 +169,20 @@ void File::GetLine(string& str)
str = fgetline(m_file);
}
// GetLines - get all lines from a file
template<typename STRING> static void FileGetLines(File & file, std::vector<STRING>& lines)
{
STRING line;
while (!file.IsEOF())
{
file.GetLine(line);
lines.push_back(line);
}
}
void File::GetLines(std::vector<std::wstring>& lines) { FileGetLines(*this, lines); };
void File::GetLines(std::vector<std::string>& lines) { FileGetLines(*this, lines); }
// Put a zero/space terminated wstring into a file
// val - value to write to the file
File& File::operator<<(const std::wstring& val)

Просмотреть файл

@ -8,6 +8,7 @@
#define _BASICS_H_
#include "basetypes.h" // TODO: gradually move over here all that's needed of basetypes.h, then remove basetypes.h.
#include "Platform.h"
#define TWO_PI 6.283185307f // TODO: find the official standards-confirming definition of this and use it instead
@ -25,61 +26,46 @@ namespace Microsoft { namespace MSR { namespace CNTK {
bool operator()(const std::wstring& left, const std::wstring& right) { return _wcsicmp(left.c_str(), right.c_str()) < 0; }
};
// ThrowFormatted() - template function to throw a std::exception with a formatted error string
template<class E>
__declspec_noreturn static inline void ThrowFormatted(const char * format, ...)
{
va_list args;
char buffer[1024];
va_start(args, format);
vsprintf(buffer, format, args);
throw E(buffer);
};
// if it receives a lonely std::string then throw that directly
template<class E>
__declspec_noreturn static inline void ThrowFormatted(const string & message) { throw E(message); }
// RuntimeError - throw a std::runtime_error with a formatted error string
#ifdef _MSC_VER
__declspec(noreturn)
#endif
static inline void RuntimeError(const char * format, ...)
{
va_list args;
char buffer[1024];
va_start(args, format);
vsprintf(buffer, format, args);
throw std::runtime_error(buffer);
};
static inline void RuntimeError(const string & message) { RuntimeError("%s", message.c_str()); }
// LogicError - throw a std::logic_error with a formatted error string
#ifdef _MSC_VER
__declspec(noreturn)
#endif
static inline void LogicError(const char * format, ...)
{
va_list args;
char buffer[1024];
va_start(args, format);
vsprintf(buffer, format, args);
throw std::logic_error(buffer);
};
static inline void LogicError(const string & message) { LogicError("%s", message.c_str()); }
// InvalidArgument - throw a std::logic_error with a formatted error string
#ifdef _MSC_VER
__declspec(noreturn)
#endif
static inline void InvalidArgument(const char * format, ...)
{
va_list args;
char buffer[1024];
va_start(args, format);
vsprintf(buffer, format, args);
throw std::invalid_argument(buffer);
};
static inline void InvalidArgument(const string & message) { InvalidArgument("%s", message.c_str()); }
template<class... _Types>
__declspec_noreturn static inline void RuntimeError(_Types&&... _Args) { ThrowFormatted<std::runtime_error>(forward<_Types>(_Args)...); }
template<class... _Types>
__declspec_noreturn static inline void LogicError(_Types&&... _Args) { ThrowFormatted<std::logic_error>(forward<_Types>(_Args)...); }
template<class... _Types>
__declspec_noreturn static inline void InvalidArgument(_Types&&... _Args) { ThrowFormatted<std::invalid_argument>(forward<_Types>(_Args)...); }
// Warning - warn with a formatted error string
static inline void Warning(const char * format, ...)
{
va_list args;
char buffer[1024];
va_start(args, format);
vsprintf(buffer, format, args);
};
static inline void Warning(const string & message) { Warning("%s", message.c_str()); }
// ----------------------------------------------------------------------------
// random collection of stuff we needed at some place
// ----------------------------------------------------------------------------
// TODO: maybe change to type id of an actual thing we pass in
// TODO: is this header appropriate?
template<class C> static wstring TypeId() { return msra::strfun::utf16(typeid(C).name()); }
// ----------------------------------------------------------------------------
// dynamic loading of modules --TODO: not Basics, should move to its own header
@ -91,7 +77,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
HMODULE m_hModule; // module handle for the writer DLL
std::wstring m_dllName; // name of the writer DLL
public:
Plugin() { m_hModule = NULL; }
Plugin() : m_hModule(NULL) { }
template<class STRING> // accepts char (UTF-8) and wide string
FARPROC Load(const STRING & plugin, const std::string & proc)
{
@ -99,13 +85,12 @@ namespace Microsoft { namespace MSR { namespace CNTK {
m_dllName += L".dll";
m_hModule = LoadLibrary(m_dllName.c_str());
if (m_hModule == NULL)
Microsoft::MSR::CNTK::RuntimeError("Plugin not found: %s", msra::strfun::utf8(m_dllName).c_str());
RuntimeError("Plugin not found: %s", msra::strfun::utf8(m_dllName).c_str());
// create a variable of each type just to call the proper templated version
return GetProcAddress(m_hModule, proc.c_str());
}
~Plugin(){}
// removed because this causes the exception messages to be lost (exception vftables are unloaded when DLL is unloaded)
// we do not unload because this causes the exception messages to be lost (exception vftables are unloaded when DLL is unloaded)
// ~Plugin() { if (m_hModule) FreeLibrary(m_hModule); }
};
#else
@ -114,11 +99,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
private:
void *handle;
public:
Plugin()
{
handle = NULL;
}
Plugin() : handle (NULL) { }
template<class STRING> // accepts char (UTF-8) and wide string
void * Load(const STRING & plugin, const std::string & proc)
{
@ -126,14 +107,10 @@ namespace Microsoft { namespace MSR { namespace CNTK {
soName = soName + ".so";
void *handle = dlopen(soName.c_str(), RTLD_LAZY);
if (handle == NULL)
RuntimeError("Plugin not found: %s", soName.c_str());
RuntimeError("Plugin not found: %s (error: %s)", soName.c_str(), dlerror());
return dlsym(handle, proc.c_str());
}
~Plugin() {
if (handle != NULL)
dlclose(handle);
}
~Plugin() { if (handle != NULL) dlclose(handle); }
};
#endif

Просмотреть файл

@ -22,10 +22,12 @@
#else
#define DATAREADER_API
#endif
#include "Basics.h"
#include "Matrix.h"
#include "commandArgUtil.h" // for ConfigParameters
#include <map>
#include <string>
#include "Basics.h"
namespace Microsoft { namespace MSR { namespace CNTK {
@ -84,7 +86,7 @@ public:
virtual void SetLabelMapping(const std::wstring&, const std::map<LabelIdType, LabelType>&) { NOT_IMPLEMENTED; };
virtual bool GetData(const std::wstring&, size_t, void*, size_t&, size_t) { NOT_IMPLEMENTED; };
virtual bool DataEnd(EndDataType) { NOT_IMPLEMENTED; };
virtual void SetSentenceSegBatch(Matrix<ElemType>&, vector<MinibatchPackingFlag>& ) { NOT_IMPLEMENTED; };
virtual void SetSentenceSegBatch(Matrix<float>&, vector<MinibatchPackingFlag>& ) { NOT_IMPLEMENTED; };
virtual void SetRandomSeed(unsigned seed = 0) { m_seed = seed; };
virtual bool GetProposalObs(std::map<std::wstring, Matrix<ElemType>*>*, const size_t, vector<size_t>&) { return false; }
virtual void InitProposals(std::map<std::wstring, Matrix<ElemType>*>*) { }
@ -103,7 +105,7 @@ public:
virtual bool GetMinibatchCopy(
std::vector<std::vector<std::pair<wstring, size_t>>>& /*uttInfo*/,
std::map<std::wstring, Matrix<ElemType>*>& /*matrices*/,
Matrix<ElemType>& /*sentenceBegin*/,
Matrix<float>& /*sentenceBegin*/,
std::vector<MinibatchPackingFlag>& /*minibatchPackingFlag*/)
{
return false;
@ -114,7 +116,7 @@ public:
virtual bool SetNetOutput(
const std::vector<std::vector<std::pair<wstring, size_t>>>& /*uttInfo*/,
const Matrix<ElemType>& /*outputs*/,
const Matrix<ElemType>& /*sentenceBegin*/,
const Matrix<float>& /*sentenceBegin*/,
const std::vector<MinibatchPackingFlag>& /*minibatchPackingFlag*/)
{
return false;
@ -225,7 +227,7 @@ public:
virtual bool GetMinibatchCopy(
std::vector<std::vector<std::pair<wstring, size_t>>>& uttInfo,
std::map<std::wstring, Matrix<ElemType>*>& matrices,
Matrix<ElemType>& sentenceBegin,
Matrix<float>& sentenceBegin,
std::vector<MinibatchPackingFlag>& minibatchPackingFlag);
// Sets the neural network output to the reader. This can be useful if some
@ -233,10 +235,10 @@ public:
virtual bool SetNetOutput(
const std::vector<std::vector<std::pair<wstring, size_t>>>& uttInfo,
const Matrix<ElemType>& outputs,
const Matrix<ElemType>& sentenceBegin,
const Matrix<float>& sentenceBegin,
const std::vector<MinibatchPackingFlag>& minibatchPackingFlag);
void SetSentenceSegBatch(Matrix<ElemType> & sentenceBegin, vector<MinibatchPackingFlag>& minibatchPackingFlag);
void SetSentenceSegBatch(Matrix<float> & sentenceBegin, vector<MinibatchPackingFlag>& minibatchPackingFlag);
void SetRandomSeed(int);

Просмотреть файл

@ -0,0 +1,140 @@
thoughts on DataReader redesign
===============================
current basic usage pattern:
- StartMinibatchLoop() sets the start and MB size
- GetMinibatch() fills matrices in a dictionary of named matrices
- sample code:
std::map<std::wstring, Matrix<ElemType>*> matrices;
matrices[featureNames[0]] = &featuresMatrix;
matrices[labelNames[0]] = &labelsMatrix;
dataReader.StartMinibatchLoop(mbSize, epoch, epochSize);
while (dataReader.GetMinibatch(matrices))
{
Matrix<ElemType>& features = *matrices[featureNames[0]];
Matrix<ElemType>& labels = *matrices[labelNames[0]];
// no function called at end, implied in GetMinibatch()
issues with current data reader design:
- monolithic, combines all (or not) of these:
- paging in data (incl. format parsing)
- randomization
- caching
- packing of parallel streams
- prefetch (in the original DBN.exe version of the HTK reader)
- minibatch decimation in presence of MPI (done in a way that avoids to read data that is not needed by a node)
- multiple streams must match in their timing frame-by-frame
- kills sequence-to-sequence
- currently circumvented by paring multiple networks
goals:
- remove time-synchronity limitation
- which means that the interface must separate the notion of frames and utterances
- break into composable blocks
- hopefully, people in the future will only have to implement the data paging
- note: packing is not a reader function; nodes themselves may determine packing for each minibatch
- more abstract notion of 'utterance' e.g. include variable-size images (2D) and video (3D)
- seems we canb keep the existing DataReader interface, but with extensions and a new implementation
feature details that must be considered/covered:
- augmentation of context frames
- some utterances are missing
- multi-lingual training (multi-task learning where each utterance only has labels for one task)
- realignment may fail on some utterances
- should support non-matrix data, e.g. lattices
- maybe we can improve efficiency of decimated minibatch reading (current approach from DBN.exe is not optimally load-balanced)
thinking out loud on how we may proceed (high level):
- basic unit of thought is the utterance, not the minibatch
- a minibatch is a set of utterances
- framewise CE training: each frame is an utterance; the N frames are batched into N streams of 1 frame
- note: design must be non-wasteful in this important special case
- an utterance should be understood more generally as a fixed or variable-dimension N-dimensional tensor,
including images (2D tensor, of possibly variable size) and even video (3D tensor).
And 'utterance length' generalizes to image dimensions as well. Everything that's variable length.
- interface Sequencer
- determines the sequence of utterances and grouping into minibatches
- by driving on utterance level, different feature streams with mismatching timing are not a concern of the Sequencer
- owns knowlegde of blocks
- provides caching control information, that is, when to release data from memory
- for frame mode, there must be some form of translation between utterances and frames, so that we can cache utterances while randomizing over frames
- does NOT actually read data; only provides descriptors of what to read, which are passed to pagers
- DataReader class does the reading
- in eval mode, there is also a DataWriter
- class RandomSequencer
- performs block randomization, based on one user-selected data pager
- for SGD
- class RandomFrameSequencer
- treats frames of utterances into individual utterances and randomizes those (for CE training of DNNs or other windows models)
- class LinearSequencer
- returns data in original sequence
- for evaluation
- interface DataPager
- random access to page in utterances
- specified by a descriptor obtained from the Sequencer
- knowledge of how to parse input data formats is in these pagers
- data assumed immutable
- examples:
- HTK features
- HTK labels from MLF (also: labels stored in feature format, for reduced startup time)
- Python adapter
- lightweight agreement between DataPager and Sequencer:
- pager provides block-forming relevant information, such that the reading of data consecutively in each block will be optimal;
sequencer will ask one user-selected pager to provide this information as a basis for block randomization
- class CachedDataPager
- TODO: think this through:
- are DataPagers driven in blocks? That would be the unit of caching
- releasing a block from cache must be an explicit function
- maybe that helper class needs to do that
- or we use a ref count for utterances to control releasing of blocks? Could be expensive, since invidiual speech frames can be utterances (DNN/CE). It's only a refcount of 0 or 1
- should we call this DataPageReader or DataBlockReader?
- let's call them Pager for now (need to change because name has a problem with reading vs. writing)
- class DataReader
- outer layer of the new structure
- designed to support reading data streams that that have mismatching utterance lengths
- there is only one DataReader instance that handles all utterance-data streams (including mismatching lenths)
- takes a reference to one user-specified sequencer
- takes ownership of one or more user-supplied pagers
- after construction, the above are only accessed through the DataReader
- a nested hierarchy of DataReaders implement specific functionaliaty
class CachingDataReader
- wraps a DataReader with caching--this is what one would use when randomizing (not needed for evaluation)
class PrefetchingDataReader
- wraps a DataReader and reads ahead on a parallel thread
- TODO: so where does the sequencer run?? Or does sequencer provides a FIFO of minibatches (lookahead)?
maybe sequence info is routed through the prefetch for everything? Consider that we also need to do writing, so this becomes weird
in that one would always have to access the sequencer through the DataReader (in order to get the correctly delayed sequencing information)
class BatchingDataReader?
- meant to batch utterances into streams
- NO: this should not be a DataReader, as this is a network function. But we may have supporting code in the reader interface or a reader helper class
- instead, there should be a Minibatch batcher class that (re-)batches and reshuffles minibatches (this could be a ComputeNode, actually)
- this new DataReader differs (extends) the current DataReader as:
- GetMinibatch() has to return utterance and length/packing information for every minibatch
- minibatches must also carry their own sequencing information (utterance ids); this can then be used for data writing
- we may want to rethink the glue between reading and Input nodes. Maybe Input nodes can know about readers?
- how to set up the whole thing:
- create all desired data pagers
- create a Sequencer of the desired type
- e.g. random utterance, random frame, non-random
- pass it one user-selected data pager to let it determining how data is grouped in blocks
- create the desired DataReader by passing it the readers and the sequencer
- there may be a hierarchy of nested DataReaders, to do caching and prefetching
- use it mostly as before
- Note: sequencer information can only be accessed through the DataReader.
on writing:
- writing is used for evaluation, and also for data conversion
- DataReader returns minibatches that carry their utterance information (i.e. utterance ids)
- class DataWriter
- new SaveData() overload takes an output minibatch complete with utterance information
TODO: we could make the two interfaces a little more symmetric w.r.t. function naming
[fseide 9/2015]

Просмотреть файл

@ -25,10 +25,10 @@
#include "Basics.h"
#include "Matrix.h"
#include "commandArgUtil.h" // for ConfigParameters
#include <map>
#include <string>
namespace Microsoft { namespace MSR { namespace CNTK {
// type of data in this section

Просмотреть файл

@ -4,6 +4,8 @@
// </copyright>
//
#pragma once
#include "Basics.h"
#include <stdio.h>
#include <string>
#include <vector>
@ -15,6 +17,8 @@
#include <unistd.h>
#endif
#include "fileutil.h" // for f{ge,pu}t{,Text}()
#include <fstream> // for LoadMatrixFromTextFile() --TODO: change to using this File class
#include <sstream>
namespace Microsoft{ namespace MSR { namespace CNTK {
@ -127,6 +131,8 @@ public:
void GetLine(std::wstring& str);
void GetLine(std::string& str);
void GetLines(std::vector<std::wstring>& lines);
void GetLines(std::vector<std::string>& lines);
// put operator for basic types
template <typename T>
@ -238,6 +244,74 @@ public:
return *this;
}
// Read a matrix stored in text format from 'filePath' (whitespace-separated columns, newline-separated rows),
// and return a flat array containing the contents of this file in column-major format.
// filePath: path to file containing matrix in text format.
// numRows/numCols: after this function is called, these parameters contain the number of rows/columns in the matrix.
// returns: a flat array containing the contents of this file in column-major format
// NOTE: caller is responsible for deleting the returned buffer once it is finished using it.
// TODO: change to return a std::vector<ElemType>; solves the ownership issue
// This function does not quite fit here, but it fits elsewhere even worse. TODO: change to use File class!
template<class ElemType>
static vector<ElemType> LoadMatrixFromTextFile(const std::string filePath, size_t& numRows, size_t& numCols)
{
size_t r = 0;
size_t numColsInFirstRow = 0;
// NOTE: Not using the Microsoft.MSR.CNTK.File API here because it
// uses a buffer of fixed size, which doesn't allow very long rows.
// See fileutil.cpp fgetline method (std::string fgetline (FILE * f) { fixed_vector<char> buf (1000000); ... })
std::ifstream myfile(filePath);
// load matrix into vector of vectors (since we don't know the size in advance).
std::vector<std::vector<ElemType>> elements;
if (myfile.is_open())
{
std::string line;
while (std::getline(myfile, line))
{
// Break on empty line. This allows there to be an empty line at the end of the file.
if (line == "")
break;
istringstream iss(line);
ElemType element;
int numElementsInRow = 0;
elements.push_back(std::vector<ElemType>());
while (iss >> element)
{
elements[r].push_back(element);
numElementsInRow++;
}
if (r == 0)
numColsInFirstRow = numElementsInRow;
else if (numElementsInRow != numColsInFirstRow)
RuntimeError("The rows in the provided file do not all have the same number of columns: " + filePath);
r++;
}
myfile.close();
}
else
RuntimeError("Unable to open file");
numRows = r;
numCols = numColsInFirstRow;
vector<ElemType> array(numRows * numCols);
// Perform transpose when copying elements from vectors to ElemType[],
// in order to store in column-major format.
for (int i = 0; i < numCols; i++)
{
for (int j = 0; j < numRows; j++)
array[i * numRows + j] = elements[j][i];
}
return array;
}
operator FILE*() const { return m_file; }
};

Просмотреть файл

@ -222,7 +222,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
}
// for raw pointer
template<typename ElemType>
template<class ElemType>
void AllReduce(ElemType* pData, size_t nData)
{
if ((NumNodesInUse() > 1 && (Communicator() != MPI_COMM_NULL)))
@ -231,7 +231,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
}
}
template<typename ElemType>
template<class ElemType>
void Bcast(ElemType* pData, size_t nData, size_t srcRank)
{
if ((NumNodesInUse() > 1) && (Communicator() != MPI_COMM_NULL))
@ -247,3 +247,5 @@ namespace Microsoft { namespace MSR { namespace CNTK {
}
};
}}}
extern Microsoft::MSR::CNTK::MPIWrapper *g_mpi;

Просмотреть файл

@ -9,6 +9,16 @@
#define __UNIX__
#endif
// ===========================================================================
// stuff to avoid compiler warnings
// ===========================================================================
#ifdef _MSC_VER
#define __declspec_noreturn __declspec(noreturn)
#else
#define __declspec_noreturn
#endif
// ===========================================================================
// emulation of some MSVC proprietary CRT
// ===========================================================================
@ -53,10 +63,6 @@ typedef void* HANDLE;
#define VOID void
#define CONST const
//standard library conversion
//#define min std::min
#define hash_map unordered_map
//macro conversion
#define __forceinline inline
//string and io conversion

Просмотреть файл

@ -0,0 +1,513 @@
// BrainScriptObjects.h -- objects that the config parser operates on
#pragma once
#include "Basics.h"
#include <memory> // for shared_ptr<>
#include <functional> // for function<>
namespace Microsoft { namespace MSR { namespace ScriptableObjects {
using namespace std;
using namespace msra::strfun; // for wstrprintf()
using namespace Microsoft::MSR::CNTK; // for stuff from Basics.h
// -----------------------------------------------------------------------
// ScriptingError -- base class for any errors thrown by scripting
// It's a runtime_error with an additional virtual function PrintError().
// -----------------------------------------------------------------------
class ScriptingError : public runtime_error
{
public:
template<typename M>
ScriptingError(const M & msg) : runtime_error(msg) { }
virtual void PrintError() const = 0;
};
// -----------------------------------------------------------------------
// Object -- common base class for objects that can be used in config files
// -----------------------------------------------------------------------
// All values that can be used in config files
// - are heap objects
// - primitives are wrapped
// - object pointers are ref-counted shared_ptr, wrapped in ConfigValuePtr (see BrainScriptEvaluator.h)
// - derive from Object (outside classes get wrapped)
//
// This code supports three kinds of value types:
// - self-defined classes -> derive from Object, e.g. Expression
// - classes defined outside -> wrap in a BoxOf object, e.g. String = BoxOf<wstring>
// - C++ primitives like 'double' -> wrap in a Wrapper first then in a BoxOf, e.g. Number = BoxOf<Wrapped<double>>
struct Object { virtual ~Object() { } };
// indicates that the object has a name should be set from the expression path
struct HasName { virtual void SetName(const wstring & name) = 0; };
// -----------------------------------------------------------------------
// Wrapped<T> -- wraps non-class primitive C++ type into a class, like 'double'.
// (It can also be used for class types, but better use BoxOf<> below directly.)
// -----------------------------------------------------------------------
template<typename T> class Wrapped
{
T value; // meant to be a primitive type
public:
operator const T&() const { return value; }
operator T&() { return value; }
Wrapped(T value) : value(value) { }
T & operator=(const T & newValue) { value = newValue; }
};
typedef Wrapped<double> Double;
typedef Wrapped<bool> Bool;
// -----------------------------------------------------------------------
// BoxOf<T> -- wraps a pre-defined type, e.g. std::wstring, to derive from Object.
// BoxOf<T> can dynamic_cast to T (e.g. BoxOf<wstring> is a wstring).
// -----------------------------------------------------------------------
template<class C>
class BoxOf : public Object, public C
{
public:
#if 1
template<class... _Types> BoxOf(_Types&&... _Args) : C(forward<_Types>(_Args)...) { }
#else
// TODO: change this to variadic templates, then we can instantiate everything we need through this
BoxOf(const C & val) : C(val) { }
BoxOf(){}
#endif
};
// -----------------------------------------------------------------------
// String -- a string in config files
// Can cast to wstring (done in a way that ConfigValuePtr can also cast to wstring).
// -----------------------------------------------------------------------
typedef BoxOf<wstring> String;
// -----------------------------------------------------------------------
// ComputationNodeObject -- the 'magic' class that our parser understands for infix operations
// TODO: unify with ComputationNodeBase
// -----------------------------------------------------------------------
class ComputationNodeObject : public Object { }; // a base class for all nodes (that has no template parameter)
// -----------------------------------------------------------------------
// HasToString -- trait to indicate an object can print their content
// Derive from HasToString() and implement ToString() method.
// FormatConfigValue() will then return ToString().
// -----------------------------------------------------------------------
struct HasToString
{
virtual wstring ToString() const = 0;
// some string helpers useful for ToString() operations of nested structures
// TODO: move these out from this header into some more general place (I had to move them here because otherwise CNTKEval failed to compile)
static wstring IndentString(wstring s, size_t indent)
{
const wstring prefix(indent, L' ');
size_t pos = 0;
for (;;)
{
s.insert(pos, prefix);
pos = s.find(L'\n', pos + 2);
if (pos == wstring::npos)
return s;
pos++;
}
}
static wstring NestString(wstring s, wchar_t open, bool newline, wchar_t close)
{
wstring result = IndentString(s, 2);
if (newline) // have a new line after the open symbol
result = L" \n" + result + L"\n ";
else
result.append(L" ");
result.front() = open;
result.back() = close;
return result;
}
};
// -----------------------------------------------------------------------
// WithTag -- trait to give an object a tag string
// -----------------------------------------------------------------------
class WithTag
{
wstring m_tag;
public:
WithTag(){}
void SetTag(const wstring & tag) { m_tag = tag; }
const wstring & GetTag() const { return m_tag; }
};
// =======================================================================
// ConfigValuePtr -- shared pointer to a config value
// =======================================================================
// A ConfigValuePtr holds the value of a configuration variable.
// - specifically, it holds a shared_ptr to a strongly typed C++ object
// - ConfigValuePtrs are immutable when consumed.
//
// All configuration values, that is, values that can be held by a ConfigValuePtr, derive from BS::Object.
// To get a shared_ptr<T> of an expected type T, type-cast the ConfigValuePtr to it.
// To get the value of a copyable type like T=double or wstring, type-cast to T directly.
//
// ConfigValuePtrs are evaluated on-demand upon first retrieval:
// - initially, a ConfigValuePtr would hold a Thunk; that is, a lambda that computes (resolves) the value
// - upon first use, the Thunk is invoked to compute the value, which will then *replace* the Thunk
// - any consumer of a ConfigValuePtr will only ever see the resolved value, since any access for consumption will force it to be resolved
// - a resolved ConfigValuePtr is immutable
//
// On-demand evaluation is critical to the semantics of this entire configuration system.
// A configuration is but one big expression (of nested records), but some evaluations cause side effects (such as saving a model), and some expressions may not even be in use at all.
// Thus, we must use on-demand evaluation in order to ensure that side effects are only executed when desired.
//
// Further, to ensure a Thunk is executed at most once (otherwise we may get the same side-effect multiple times),
// an unresolved ConfigValuePtr can only live in a single place. This means,
// - an unresolved ConfigValuePtr (i.e. one holding a Thunk) cannot be copied (while resolved ones are immutable and can be copied freely)
// - it can be moved (std::move()) during creation
// - after creation, it should only live in a known location from which it can be retrieved; specifically:
// - ConfigRecord entries
// - ConfigArrays elements
// - ConfigLambdas (default values of named arguments)
// TODO: separate this out from BrainScript to an interface that still does type casts--possible?
class ConfigValuePtr : public shared_ptr<Object>
{
function<void(const wstring &)> failfn; // function to call in case of failure due to this value
wstring expressionName; // the expression name reflects the path to reach this expression in the (possibly dynamically macro-expanded) expression tree. Used for naming ComputationNodes.
// Thunk for resolving a value. This Object represents a function that returns a ConfigValuePtr; call to resolve a deferred value
class Thunk : public Object
{
function<ConfigValuePtr()> f; // the function to compute the value
bool currentlyResolving; // set during resolution phase, to detect circular references
function<void(const wstring &)> failfn; // function to call in case of failure due to this value
public:
Thunk(function<ConfigValuePtr()> f, const function<void(const wstring &)> & failfn) : f(f), failfn(failfn), currentlyResolving(false) { }
ConfigValuePtr ResolveValue()
{
if (currentlyResolving) // detect circular references (infinite recursion)
failfn(L"circular reference (expression to compute identifier's value uses the identifier's value)");
currentlyResolving = true; // can't run from inside ourselves
return f();
// no need to reset currentlyResolving because this object gets replaced and thus deleted anyway
}
};
Thunk * GetThunk() const { return dynamic_cast<Thunk*>(get()); } // get Thunk object or nullptr if already resolved
public:
// --- assignment and copy/move constructors
ConfigValuePtr() {} // (formally needed somehow)
ConfigValuePtr(const shared_ptr<Object> & p, const function<void(const wstring &)> & failfn, const wstring & expressionName) : shared_ptr<Object>(p), failfn(failfn), expressionName(expressionName) { }
//ConfigValuePtr(const function<ConfigValuePtr()> & f, TextLocation location, const wstring & expressionName) : shared_ptr<Object>(make_shared<Thunk>(f, location)), location(location), expressionName(expressionName) { }
static ConfigValuePtr MakeThunk(const function<ConfigValuePtr()> & f, const function<void(const wstring &)> & failfn, const wstring & expressionName)
{
return ConfigValuePtr(make_shared<Thunk>(f, failfn), failfn, expressionName);
}
// TODO: somehow the constructor overload from Thunk function fails to compile, so for now use MakeThunk instead
ConfigValuePtr(const ConfigValuePtr & other) { *this = other; }
ConfigValuePtr(ConfigValuePtr && other) { *this = move(other); }
void operator=(const ConfigValuePtr & other)
{
if (other.GetThunk()) // unresolved ConfigValuePtrs are not copyable, only movable
Microsoft::MSR::CNTK::LogicError("ConfigValuePtr::operator=() on unresolved object; ConfigValuePtr is not assignable until resolved");
(shared_ptr<Object>&)*this = other;
failfn = other.failfn;
expressionName = other.expressionName;
}
void operator=(ConfigValuePtr && other)
{
failfn = move(other.failfn);
expressionName = move(other.expressionName);
(shared_ptr<Object>&)*this = move(other);
}
void Fail(const wstring & msg) const { failfn(msg); }
const function<void(const wstring &)> & GetFailFn() const { return failfn; } // if you need to pass on the fail function
// --- retrieving values by type cast
// access as a reference, that is, as a shared_ptr<T> --use this for Objects
template<typename T> operator shared_ptr<T>() const { return AsPtr<T>(); }
// access as a (const & to) value --use this for primitive types (also works to get a const wstring & from a String)
template<typename T> operator T() const { return AsRef<T>(); }
// Linux gcc barfs on this ^^ for 'us = (double)((wstring)arg).size();' due to some ambiguity error (while it works fine with Visual Studio).
// If you encounter this, instead say 'us = (double)((wstring&)arg).size();' with a &
operator double() const { return AsRef<Double>(); }
operator float() const { return (float) AsRef<Double>(); }
operator bool() const { return AsRef<Bool>(); }
template<typename INT> INT AsInt() const
{
double val = AsRef<Double>();
INT ival = (INT)val;
const wchar_t * type = L"size_t";
const char * t = typeid(INT).name(); t;
// TODO: there is some duplication of type checking; can we unify that?
if (ival != val)
Fail(wstrprintf(L"expected expression of type %ls instead of floating-point value %f", type, val));
return ival;
}
operator size_t() const { return AsInt<size_t>(); }
operator int() const { return AsInt<int>(); }
// --- access functions
template<class C>
bool Is() const
{
EnsureIsResolved();
const auto p = dynamic_cast<C*>(get());
return p != nullptr;
}
template<class C>
const C & AsRef() const // returns reference to what the 'value' member. Configs are considered immutable, so return a const&
{
// TODO: factor these lines into a separate function
// Note: since this returns a reference into 'this', you must keep the object you call this on around as long as you use the returned reference
EnsureIsResolved();
//const C * wanted = (C *) nullptr; const auto * got = get(); wanted; got; // allows to see C in the debugger
const auto p = dynamic_cast<C*>(get());
if (p == nullptr) // TODO: can we make this look the same as TypeExpected in BrainScriptEvaluator.cpp? We'd need the type name
Fail(L"config member has wrong type (" + msra::strfun::utf16(typeid(*get()).name()) + L"), expected a " + TypeId<C>());
return *p;
}
template<class C>
shared_ptr<C> AsPtr() const // returns a shared_ptr cast to the 'value' member
{
EnsureIsResolved();
const auto p = dynamic_pointer_cast<C>(*this);
if (!p) // TODO: can we make this look the same as TypeExpected in BrainScriptEvaluator.cpp? We'd need the type name
Fail(L"config member has wrong type (" + msra::strfun::utf16(typeid(*get()).name()) + L"), expected a " + TypeId<C>());
return p;
}
// --- properties
const char * TypeName() const { return typeid(*get()).name(); }
const wstring & GetExpressionName() const{ return expressionName; }
// TODO: ^^ it seems by saving the name in the ConfigValuePtr itself, we don't gain anything; maybe remove again in the future
// --- methods for resolving the value
const ConfigValuePtr & ResolveValue() const // (this is const but mutates the value if it resolves)
{
// call this when a a member might be as-of-yet unresolved, to evaluate it on-demand
// get() is a pointer to a Thunk in that case, that is, a function object that yields the value
const auto thunkp = GetThunk(); // is it a Thunk?
if (thunkp) // value is a Thunk: we need to resolve
{
const auto value = thunkp->ResolveValue(); // completely replace ourselves with the actual result. This also releases the Thunk object
const_cast<ConfigValuePtr&>(*this) = value;
ResolveValue(); // allow it to return another Thunk...
}
return *this; // return ourselves so we can access a value as p_resolved = p->ResolveValue()
}
void EnsureIsResolved() const
{
if (GetThunk())
Microsoft::MSR::CNTK::LogicError("ConfigValuePtr: unexpected access to unresolved object; ConfigValuePtrs can only be accessed after resolution");
}
}; // ConfigValuePtr
// use this for primitive values, double and bool
template<typename T> static inline ConfigValuePtr MakePrimitiveConfigValuePtr(const T & val, const function<void(const wstring &)> & failfn, const wstring & exprPath)
{
return ConfigValuePtr(make_shared<BoxOf<Wrapped<T>>>(val), failfn, exprPath);
}
// -----------------------------------------------------------------------
// IConfigRecord -- config record
// Inside BrainScript, this would be a BS::ConfigRecord, but outside of the
// evaluator, we will only pass it through this interface, to allow for
// extensibility (e.g. Python interfacing).
// Also, Objects themselves can expose this interface to make something accessible.
// -----------------------------------------------------------------------
struct IConfigRecord // any class that exposes config can derive from this
{
virtual const ConfigValuePtr & operator[](const wstring & id) const = 0; // e.g. confRec[L"message"]
virtual const ConfigValuePtr * Find(const wstring & id) const = 0; // returns nullptr if not found
virtual vector<wstring> GetMemberIds() const = 0; // returns the names of all members in this record (but not including parent scopes)
};
typedef shared_ptr<struct IConfigRecord> IConfigRecordPtr;
// -----------------------------------------------------------------------
// ConfigRecord -- collection of named config values
// -----------------------------------------------------------------------
class ConfigRecord : public Object, public IConfigRecord // all configuration arguments to class construction, resolved into ConfigValuePtrs
{
function<void(const wstring &)> failfn; // function to call in case of failure due to this value
// change to ContextInsensitiveMap<ConfigValuePtr>
map<wstring, ConfigValuePtr> members;
IConfigRecordPtr parentScope; // we look up the chain
ConfigRecord() { } // forbidden (private) to instantiate without a scope
public:
// --- creation phase
ConfigRecord(IConfigRecordPtr parentScope, const function<void(const wstring &)> & failfn) : parentScope(parentScope), failfn(failfn) { }
void Add(const wstring & id, const function<void(const wstring &)> & failfn, const ConfigValuePtr & value) { members[id] = value; failfn; }
void Add(const wstring & id, const function<void(const wstring &)> & failfn, ConfigValuePtr && value) { members[id] = move(value); failfn; } // use this for unresolved ConfigPtrs
// TODO: Add() does not yet correctly handle the failfn. It is meant to flag the location of the variable identifier
// --- usage phase
// regular lookup: just use record[id] or record(id, L"helpful message what 'id' does")
// Any unresolved value is resolved at this time, as it is being consumed. Only after resolving a ConfigValuePtr, it can be copied.
const ConfigValuePtr & /*IConfigRecord::*/operator[](const wstring & id) const // e.g. confRec[L"name"]
{
const auto memberIter = members.find(id);
if (memberIter != members.end())
return memberIter->second.ResolveValue(); // resolve upon access
if (!parentScope) // not found: if at top scope, we fail
failfn(L"required parameter '" + id + L"' not found");
// The failfn will report the location where the dictionary itself was formed.
// This is because this function is meant to be used by C++ code.
// When we look up a name by a BrainScript ".FIELD" expression, we will use Find() so we can report the error for the offending FIELD itself.
return (*parentScope)[id]; // have parent: look it up there
}
const ConfigValuePtr * /*IConfigRecord::*/Find(const wstring & id) const // returns nullptr if not found
{
auto memberIter = members.find(id);
if (memberIter == members.end())
if (parentScope)
return parentScope->Find(id);
else
return nullptr;
else
return &memberIter->second.ResolveValue();
}
// get member ids; use this when you intend to consume all record entries and do not know the names
// Note that unlike Find() and operator[], which return parent matches, this only returns entries in this record.
virtual vector<wstring> /*IConfigRecord::*/GetMemberIds() const
{
vector<wstring> ids;
for (auto & member : members)
ids.push_back(member.first);
return ids;
}
};
typedef shared_ptr<ConfigRecord> ConfigRecordPtr;
// TODO: can ConfigRecordPtr be IConfigRecordPtr?
// create a runtime object from its type --general case
// There can be specializations of this that instantiate objects that do not take ConfigRecords or involve mapping like ComputationNode.
template<typename C>
shared_ptr<Object> MakeRuntimeObject(const IConfigRecordPtr config)
{
return make_shared<C>(config);
}
// -----------------------------------------------------------------------
// ConfigArray -- an array of config values
// -----------------------------------------------------------------------
// an array is just a vector of config values
class ConfigArray : public Object
{
vector<ConfigValuePtr> values;
int firstIndex;
public:
ConfigArray() : firstIndex(0) { }
ConfigArray(int firstIndex, vector<ConfigValuePtr> && values) : firstIndex(firstIndex), values(move(values)) { }
pair<int, int> GetIndexRange() const { return make_pair(firstIndex, firstIndex+(int)values.size()-1); }
// building the array from expressions: append an element or an array
void Append(ConfigValuePtr value) { values.push_back(value); }
void Append(const ConfigArray & other) { values.insert(values.end(), other.values.begin(), other.values.end()); }
// get element at index, including bounds check
template<typename FAILFN>
const ConfigValuePtr & At(int index, const FAILFN & failfn/*should report location of the index*/) const
{
if (index < firstIndex || index >= firstIndex + values.size())
failfn(L"index out of bounds");
return values[(size_t)(index - firstIndex)].ResolveValue(); // resolve upon access
}
};
typedef shared_ptr<ConfigArray> ConfigArrayPtr;
// -----------------------------------------------------------------------
// ConfigLambda -- a lambda
// -----------------------------------------------------------------------
class ConfigLambda : public Object
{
public:
typedef map<wstring, ConfigValuePtr> NamedParams; // TODO: maybe even not use a typedef, just use the type
private:
// the function itself is a C++ lambda
function<ConfigValuePtr(vector<ConfigValuePtr> &&, NamedParams &&, const wstring & exprName)> f;
// inputs. This defines the interface to the function. Very simple in our case though.
// We pass rvalue references because that allows to pass Thunks.
vector<wstring> paramNames; // #parameters and parameter names (names are used for naming expressions only)
NamedParams namedParams; // lists named parameters with their default values. Named parameters are optional and thus always must have a default.
public:
template<typename F>
ConfigLambda(vector<wstring> && paramNames, NamedParams && namedParams, const F & f) : paramNames(move(paramNames)), namedParams(move(namedParams)), f(f) { }
size_t GetNumParams() const { return paramNames.size(); }
const vector<wstring> & GetParamNames() const { return paramNames; } // used for expression naming
// what this function does is call f() held in this object with the given arguments except optional arguments are verified and fall back to their defaults if not given
// The arguments are rvalue references, which allows us to pass Thunks, which is important to allow stuff with circular references like CNTK's DelayedNode.
ConfigValuePtr Apply(vector<ConfigValuePtr> && args, NamedParams && namedArgs, const wstring & exprName)
{
NamedParams actualNamedArgs;
// actualNamedArgs is a filtered version of namedArgs that contains all optional args listed in namedParams,
// falling back to their default if not given in namedArgs.
// On the other hand, any name in namedArgs that is not found in namedParams should be rejected.
for (const auto & namedParam : namedParams)
{
const auto & id = namedParam.first; // id of expected named parameter
const auto valuei = namedArgs.find(id); // was such parameter passed?
if (valuei == namedArgs.end()) // named parameter not passed
{ // if not given then fall back to default
auto f = [&namedParam]() // we pass a lambda that resolves it upon first use, in our original location
{
return namedParam.second.ResolveValue();
};
actualNamedArgs[id] = move(ConfigValuePtr::MakeThunk(f, namedParam.second.GetFailFn(), exprName));
}
else // named parameter was passed
actualNamedArgs[id] = move(valuei->second); // move it, possibly remaining unresolved
// BUGBUG: we should pass in the location of the identifier, not that of the expression
}
for (const auto & namedArg : namedArgs) // make sure there are no extra named args that the macro does not take
if (namedParams.find(namedArg.first) == namedParams.end())
namedArg.second.Fail(L"function does not have an optional argument '" + namedArg.first + L"'");
return f(move(args), move(actualNamedArgs), exprName);
}
// TODO: define an overload that takes const & for external users (which will then take a copy and pass it on to Apply &&)
};
typedef shared_ptr<ConfigLambda> ConfigLambdaPtr;
// -----------------------------------------------------------------------
// ConfigurableRuntimeType -- interface to scriptable runtime types
// -----------------------------------------------------------------------
// helper for configurableRuntimeTypes initializer below
// This returns a ConfigurableRuntimeType info structure that consists of
// - a lambda that is a constructor for a given runtime type and
// - a bool saying whether T derives from IConfigRecord
struct ConfigurableRuntimeType // TODO: rename to ScriptableObjects::Factory or something like that
{
bool isConfigRecord; // exposes IConfigRecord --in this case the expression name is computed differently, namely relative to this item
// TODO: is this ^^ actually still used anywhere?
function<shared_ptr<Object>(const IConfigRecordPtr)> construct; // lambda to construct an object of this class
// TODO: we should pass the expression name to construct() as well
};
// scriptable runtime types must be exposed by this function
// TODO: should this be a static member of above class?
const ConfigurableRuntimeType * FindExternalRuntimeTypeInfo(const wstring & typeId);
}}} // end namespaces

Просмотреть файл

@ -355,75 +355,55 @@ public:
// understood. Also, braces in strings are not protected. [fseide]
static std::string::size_type FindBraces(const std::string& str, std::string::size_type tokenStart)
{
// open braces and quote
static const std::string openBraces = OPENBRACES;
// close braces and quote
static const std::string closingBraces = CLOSINGBRACES;
const auto len = str.length();
// start is outside (or rather, at end of string): no brace here
if (tokenStart >= len) {
return npos;
}
auto braceFound = openBraces.find(str[tokenStart]);
// open braces and quote
static const std::string openBraces = OPENBRACES;
// close braces and quote
static const std::string closingBraces = CLOSINGBRACES;
const auto charsToLookFor = closingBraces + openBraces; // all chars we match for
// get brace index for first character of input string
const auto braceFound = openBraces.find(str[tokenStart]);
// no brace present at tokenStart
if (braceFound == npos) {
if (braceFound == npos)
return npos;
}
// string begins with a brace--find the closing brace, while correctly handling nested braces
std::vector<std::string::size_type> bracesFound;
std::string::size_type current, opening;
current = opening = tokenStart;
// create a brace pair for string searches
std::string braces;
braces += openBraces[braceFound];
braces += closingBraces[braceFound];
std::string braceStack; // nesting stack; .back() is closing symbol for inner-most brace
braceStack.push_back(closingBraces[braceFound]); // closing symbol for current
// search for end brace or other nested layers of this brace type
while (current != npos && current + 1 < len)
for (auto current = tokenStart; current + 1 < len;)
{
current = str.find_first_of(braces, current + 1);
// check for a nested opening brace
if (current == npos)
{
// look for closing brace and also for another opening brace
// Inside strings we only accept the closing quote, and ignore any braces inside.
current = str.find_first_of(braceStack.back() == '"' ? "\"" : charsToLookFor, current + 1); //
if (current == string::npos) // none found: done or error
break;
}
// found a closing brace
if (str[current] == braces[1])
char brace = str[current];
// found the expected closing brace?
if (brace == braceStack.back())
{
// no braces on the stack, we are done
if (bracesFound.empty())
{
braceStack.pop_back(); // yes: pop up and continue (or stop if stack is empty)
if (braceStack.empty()) // fully closed: done
return current;
}
// have braces on the stack, pop the current one off
opening = bracesFound.back();
bracesFound.pop_back();
}
// or any other closing brace? That's an error.
else if (closingBraces.find(brace) != string::npos)
RuntimeError("unmatched bracket found in parameters");
// found another opening brace, push it on the stack
else
{
// found another opening brace, push it on the stack
bracesFound.push_back(opening);
opening = current;
const auto braceFound = openBraces.find(brace); // index of brace
braceStack.push_back(closingBraces[braceFound]); // closing symbol for current
}
}
// if we found unmatched parenthesis, throw an exception
if (opening != npos)
{
RuntimeError("unmatched bracket found in parameters");
}
return current;
// hit end before everything was closed: error
RuntimeError("no closing bracket found in parameters");
}
// ParseValue - virtual function to parse a "token" as tokenized by Parse() below.
@ -981,7 +961,7 @@ public:
// ensure that this method was called on a single line (eg, no newline characters exist in 'configLine').
if (configLine.find_first_of("\n") != std::string::npos)
{
throw std::logic_error(
LogicError(
"\"ResolveVariablesInSingleLine\" shouldn't be called with a string containing a newline character");
}
@ -1028,7 +1008,7 @@ public:
if (varValue.find_first_of("\n") != std::string::npos)
{
throw std::logic_error(
LogicError(
"Newline character cannot be contained in the value of a variable which is resolved using $varName$ feature");
}

Просмотреть файл

@ -1545,7 +1545,7 @@ static BOOL ExpandWildcards (wstring path, vector<wstring> & paths)
return FALSE; // another error
}
size_t pos = path.find_last_of (L"\\");
if (pos == wstring::npos) throw std::logic_error ("unexpected missing \\ in path");
if (pos == wstring::npos) LogicError ("unexpected missing \\ in path");
wstring parent = path.substr (0, pos);
do
{

Просмотреть файл

@ -421,7 +421,7 @@ public:
size_t NumberSlicesInEachRecurrentIter() { return 1 ;}
void SetNbrSlicesEachRecurrentIter(const size_t) { };
void SetSentenceSegBatch(Matrix<ElemType> &/*sentenceBegin*/, vector<MinibatchPackingFlag>& /*sentenceExistsBeginOrNoLabels*/) {};
void SetSentenceSegBatch(Matrix<float> &/*sentenceBegin*/, vector<MinibatchPackingFlag>& /*sentenceExistsBeginOrNoLabels*/) {};
virtual const std::map<LabelIdType, LabelType>& GetLabelMapping(const std::wstring& sectionName);
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<typename BinaryReader<ElemType>::LabelIdType, typename BinaryReader<ElemType>::LabelType>& labelMapping);

Просмотреть файл

@ -72,7 +72,7 @@
<Link>
<SubSystem>Windows</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
<AdditionalDependencies>CNTKMath.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
<AdditionalDependencies>CNTKMathDll.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
<AdditionalLibraryDirectories>$(SolutionDir)$(Platform)\$(Configuration)\;..\..\math\$(Platform)\$(Configuration);..\$(Platform)\$(Configuration)</AdditionalLibraryDirectories>
</Link>
</ItemDefinitionGroup>
@ -95,7 +95,7 @@
<GenerateDebugInformation>true</GenerateDebugInformation>
<EnableCOMDATFolding>true</EnableCOMDATFolding>
<OptimizeReferences>true</OptimizeReferences>
<AdditionalDependencies>CNTKmath.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
<AdditionalDependencies>CNTKMathDll.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
<AdditionalLibraryDirectories>..\..\math\$(Platform)\$(Configuration);$(SolutionDir)$(Platform)\$(Configuration)\</AdditionalLibraryDirectories>
<Profile>true</Profile>
</Link>

Просмотреть файл

@ -143,7 +143,7 @@ public:
size_t NumberSlicesInEachRecurrentIter() { return 1 ;}
void SetNbrSlicesEachRecurrentIter(const size_t) { };
void SetSentenceSegBatch(Matrix<ElemType> &/*sentenceBegin*/, vector<MinibatchPackingFlag>& /*sentenceExistsBeginOrNoLabels*/) {};
void SetSentenceSegBatch(Matrix<float> &/*sentenceBegin*/, vector<MinibatchPackingFlag>& /*sentenceExistsBeginOrNoLabels*/) {};
virtual const std::map<LabelIdType, LabelType>& GetLabelMapping(const std::wstring& sectionName);
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, typename LabelType>& labelMapping);

Просмотреть файл

@ -74,7 +74,7 @@
<Link>
<SubSystem>Windows</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
<AdditionalDependencies>CNTKMath.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
<AdditionalDependencies>CNTKMathDll.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
<AdditionalLibraryDirectories>$(SolutionDir)$(Platform)\$(Configuration)\;..\..\math\$(Platform)\$(Configuration);..\$(Platform)\$(Configuration)</AdditionalLibraryDirectories>
</Link>
</ItemDefinitionGroup>
@ -97,7 +97,7 @@
<GenerateDebugInformation>true</GenerateDebugInformation>
<EnableCOMDATFolding>true</EnableCOMDATFolding>
<OptimizeReferences>true</OptimizeReferences>
<AdditionalDependencies>CNTKmath.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
<AdditionalDependencies>CNTKMathDll.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
<AdditionalLibraryDirectories>..\..\math\$(Platform)\$(Configuration);$(SolutionDir)$(Platform)\$(Configuration)\</AdditionalLibraryDirectories>
<Profile>true</Profile>
</Link>

Просмотреть файл

@ -10,17 +10,22 @@
#include "basetypes.h"
#include "htkfeatio.h" // for reading HTK features
#ifdef _WIN32
#include "latticearchive.h" // for reading HTK phoneme lattices (MMI training)
#endif
#include "simplesenonehmm.h" // for MMI scoring
#include "msra_mgram.h" // for unigram scores of ground-truth path in sequence training
#include "rollingwindowsource.h" // minibatch sources
#include "utterancesource.h"
#ifdef _WIN32
#include "readaheadsource.h"
#endif
#include "chunkevalsource.h"
#define DATAREADER_EXPORTS
#include "DataReader.h"
#include "HTKMLFReader.h"
#include "commandArgUtil.h"
namespace Microsoft { namespace MSR { namespace CNTK {
@ -38,6 +43,7 @@ extern "C" DATAREADER_API void GetReaderD(IDataReader<double>** preader)
{
GetReader(preader);
}
#ifdef _WIN32
// Utility function, in ConfigFile.cpp, but HTKMLFReader doesn't need that code...
// Trim - trim white space off the start and end of the string
@ -56,6 +62,7 @@ void Trim(std::string& str)
if (found != npos)
str.erase(found+1);
}
#endif
}}}

Просмотреть файл

@ -99,7 +99,7 @@ bool DataWriter<ElemType>::SaveData(size_t recordStart, const std::map<std::wstr
// saveId - name of the section to save into (section:subsection format)
// labelMapping - map we are saving to the file
template<class ElemType>
void DataWriter<ElemType>::SaveMapping(std::wstring saveId, const std::map<typename LabelIdType, typename LabelType>& labelMapping)
void DataWriter<ElemType>::SaveMapping(std::wstring saveId, const std::map<LabelIdType, LabelType>& labelMapping)
{
m_dataWriter->SaveMapping(saveId, labelMapping);
}

Просмотреть файл

@ -7,7 +7,9 @@
//
#include "stdafx.h"
#ifdef _WIN32
#include <objbase.h>
#endif
#include "basetypes.h"
#include "htkfeatio.h" // for reading HTK features
@ -19,19 +21,34 @@
#include "utterancesourcemulti.h"
#include "utterancesource.h"
#include "utterancesourcemulti.h"
#ifdef _WIN32
#include "readaheadsource.h"
#endif
#include "chunkevalsource.h"
#include "minibatchiterator.h"
#define DATAREADER_EXPORTS // creating the exports here
#include "DataReader.h"
#include "commandArgUtil.h"
#include "HTKMLFReader.h"
#ifdef LEAKDETECT
#include <vld.h> // for memory leak detection
#endif
#ifdef __unix__
#include <limits.h>
typedef unsigned long DWORD;
typedef unsigned short WORD;
typedef unsigned int UNINT32;
#endif
#pragma warning (disable: 4127) // conditional expression is constant; "if (sizeof(ElemType)==sizeof(float))" triggers this
#ifdef _WIN32
int msra::numa::node_override = -1; // for numahelpers.h
#endif
namespace msra { namespace lm {
/*static*/ const mgram_map::index_t mgram_map::nindex = (mgram_map::index_t) -1; // invalid index
}}
namespace Microsoft { namespace MSR { namespace CNTK {
@ -44,7 +61,9 @@ namespace Microsoft { namespace MSR { namespace CNTK {
m_cudaAllocator = nullptr;
m_mbiter = NULL;
m_frameSource = NULL;
#ifdef _WIN32
m_readAheadSource = NULL;
#endif
m_lattices = NULL;
m_truncated = readerConfig("Truncated", "false");
@ -202,7 +221,22 @@ namespace Microsoft { namespace MSR { namespace CNTK {
m_labelNameToIdMap[labelNames[i]]=iLabel;
m_labelNameToDimMap[labelNames[i]]=m_labelDims[i];
mlfpaths.clear();
mlfpaths.push_back(thisLabel("mlfFile"));
if (thisLabel.ExistsCurrent("mlfFile"))
{
mlfpaths.push_back(thisLabel("mlfFile"));
}
else
{
if (!thisLabel.ExistsCurrent("mlfFileList"))
{
RuntimeError("Either mlfFile or mlfFileList must exist in HTKMLFReder");
}
wstring list = thisLabel("mlfFileList");
for (msra::files::textreader r(list); r;)
{
mlfpaths.push_back(r.wgetline());
}
}
mlfpathsmulti.push_back(mlfpaths);
m_labelsBufferMultiIO.push_back(nullptr);
@ -271,7 +305,9 @@ namespace Microsoft { namespace MSR { namespace CNTK {
}
// see if they want to use readAhead
#ifdef _WIN32
m_readAhead = readerConfig("readAhead", "false");
#endif
// read all input files (from multiple inputs)
// TO DO: check for consistency (same number of files in each script file)
@ -319,7 +355,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
}
if (readerConfig.Exists("unigram"))
unigrampath = readerConfig("unigram");
unigrampath = (wstring)readerConfig("unigram");
// load a unigram if needed (this is used for MMI training)
msra::lm::CSymbolSet unigramsymbols;
@ -361,13 +397,13 @@ namespace Microsoft { namespace MSR { namespace CNTK {
//std::vector<std::wstring> pagepath;
foreach_index(i, mlfpathsmulti)
{
const msra::lm::CSymbolSet* wordmap = unigram ? &unigramsymbols : NULL;
msra::asr::htkmlfreader<msra::asr::htkmlfentry,msra::lattices::lattice::htkmlfwordsequence>
labels(mlfpathsmulti[i], restrictmlftokeys, statelistpaths[i], unigram ? &unigramsymbols : NULL, (map<string,size_t>*) NULL, htktimetoframe); // label MLF
labels(mlfpathsmulti[i], restrictmlftokeys, statelistpaths[i], wordmap, (map<string,size_t>*) NULL, htktimetoframe); // label MLF
// get the temp file name for the page file
labelsmulti.push_back(labels);
}
if (!_stricmp(readMethod.c_str(),"blockRandomize"))
{
// construct all the parameters we don't need, but need to be passed to the constructor...
@ -383,7 +419,11 @@ namespace Microsoft { namespace MSR { namespace CNTK {
}
else if (!_stricmp(readMethod.c_str(),"rollingWindow"))
{
#ifdef _WIN32
std::wstring pageFilePath;
#else
std::string pageFilePath;
#endif
std::vector<std::wstring> pagePaths;
if (readerConfig.Exists("pageFilePath"))
{
@ -391,28 +431,57 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// replace any '/' with '\' for compat with default path
std::replace(pageFilePath.begin(), pageFilePath.end(), '/','\\');
#ifdef _WIN32
// verify path exists
DWORD attrib = GetFileAttributes(pageFilePath.c_str());
if (attrib==INVALID_FILE_ATTRIBUTES || !(attrib & FILE_ATTRIBUTE_DIRECTORY))
throw std::runtime_error ("pageFilePath does not exist");
#endif
#ifdef __unix__
struct stat statbuf;
if (stat(pageFilePath.c_str(), &statbuf)==-1)
{
throw std::runtime_error ("pageFilePath does not exist");
}
#endif
}
else // using default temporary path
{
#ifdef _WIN32
pageFilePath.reserve(MAX_PATH);
GetTempPath(MAX_PATH, &pageFilePath[0]);
#endif
#ifdef __unix__
pageFilePath.reserve(PATH_MAX);
pageFilePath = "/tmp/temp.CNTK.XXXXXX";
#endif
}
#ifdef _WIN32
if (pageFilePath.size()>MAX_PATH-14) // max length of input to GetTempFileName is MAX_PATH-14
throw std::runtime_error (msra::strfun::strprintf ("pageFilePath must be less than %d characters", MAX_PATH-14));
#endif
#ifdef __unix__
if (pageFilePath.size()>PATH_MAX-14) // max length of input to GetTempFileName is PATH_MAX-14
throw std::runtime_error (msra::strfun::strprintf ("pageFilePath must be less than %d characters", PATH_MAX-14));
#endif
foreach_index(i, infilesmulti)
{
#ifdef _WIN32
wchar_t tempFile[MAX_PATH];
GetTempFileName(pageFilePath.c_str(), L"CNTK", 0, tempFile);
pagePaths.push_back(tempFile);
#endif
#ifdef __unix__
char* tempFile;
//GetTempFileName(pageFilePath.c_str(), L"CNTK", 0, tempFile);
tempFile = (char*) pageFilePath.c_str();
int fid = mkstemp(tempFile);
unlink (tempFile);
close (fid);
pagePaths.push_back(GetWC(tempFile));
#endif
}
const bool mayhavenoframe=false;
@ -513,7 +582,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
n++;
}
fprintf (stderr, " %d entries\n", n);
fprintf (stderr, " %d entries\n", (int)n);
if (i==0)
numFiles=n;
@ -534,7 +603,9 @@ namespace Microsoft { namespace MSR { namespace CNTK {
HTKMLFReader<ElemType>::~HTKMLFReader()
{
delete m_mbiter;
#ifdef _WIN32
delete m_readAheadSource;
#endif
delete m_frameSource;
delete m_lattices;
@ -587,7 +658,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
void HTKMLFReader<ElemType>::StartDistributedMinibatchLoop(size_t mbSize, size_t epoch, size_t subsetNum, size_t numSubsets, size_t requestedEpochSamples /*= requestDataSize*/)
{
assert(subsetNum < numSubsets);
assert(this->SupportsDistributedMBRead() || ((subsetNum == 0) && (numSubsets == 1)));
assert(((subsetNum == 0) && (numSubsets == 1)) || this->SupportsDistributedMBRead());
m_mbSize = mbSize;
@ -664,6 +735,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// delete the old one first (in case called more than once)
delete m_mbiter;
msra::dbn::minibatchsource* source = m_frameSource;
#ifdef _WIN32
if (m_readAhead)
{
if (m_readAheadSource == NULL)
@ -677,6 +749,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
}
source = m_readAheadSource;
}
#endif
m_mbiter = new msra::dbn::minibatchiterator(*source, epoch, requestedEpochSamples, mbSize, subsetNum, numSubsets, datapasses);
if (!m_featuresBufferMultiIO.empty())
{
@ -789,7 +862,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// now, access all features and and labels by iterating over map of "matrices"
bool first = true;
std::map<std::wstring, Matrix<ElemType>*>::iterator iter;
typename std::map<std::wstring, Matrix<ElemType>*>::iterator iter;
for (iter = matrices.begin();iter!=matrices.end(); iter++)
{
// dereference matrix that corresponds to key (input/output name) and
@ -810,9 +883,10 @@ namespace Microsoft { namespace MSR { namespace CNTK {
m_sentenceBegin.SetValue((ElemType) SEQUENCE_MIDDLE);
m_sentenceBegin.SetValue(0, 0, (ElemType) SEQUENCE_START);
m_sentenceBegin.SetValue(0, (size_t)feat.cols()-1, (ElemType) SEQUENCE_END);
std::fill(m_minibatchPackingFlag.begin(), m_minibatchPackingFlag.end(), MinibatchPackingFlag::None);
m_minibatchPackingFlag[0] = MinibatchPackingFlag::SequenceStart;
m_minibatchPackingFlag[(size_t)feat.cols() - 1] = MinibatchPackingFlag::SequenceEnd;
first = false;
}
@ -969,6 +1043,11 @@ namespace Microsoft { namespace MSR { namespace CNTK {
{
m_sentenceEnd[i] = false;
m_switchFrame[i] = m_mbSize+1;
if (m_processedFrame[i] == 1)
{
m_sentenceBegin.SetValue(i, 0, (ElemType)SEQUENCE_END);
m_minibatchPackingFlag[0] = MinibatchPackingFlag::SequenceEnd;
}
}
else
{
@ -979,7 +1058,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
}
actualmbsize[i] = m_mbSize;
endFr = startFr + actualmbsize[i];
std::map<std::wstring, Matrix<ElemType>*>::iterator iter;
typename std::map<std::wstring, Matrix<ElemType>*>::iterator iter;
for (iter = matrices.begin();iter!=matrices.end(); iter++)
{
// dereference matrix that corresponds to key (input/output name) and
@ -1044,7 +1123,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
actualmbsize[i] = m_toProcess[i] - m_processedFrame[i];
endFr = startFr + actualmbsize[i];
std::map<std::wstring, Matrix<ElemType>*>::iterator iter;
typename std::map<std::wstring, Matrix<ElemType>*>::iterator iter;
for (iter = matrices.begin();iter!=matrices.end(); iter++)
{
// dereference matrix that corresponds to key (input/output name) and
@ -1108,6 +1187,11 @@ namespace Microsoft { namespace MSR { namespace CNTK {
m_sentenceBegin.SetValue(i, actualmbsize[i], (ElemType)SEQUENCE_START);
m_minibatchPackingFlag[actualmbsize[i]] |= MinibatchPackingFlag::SequenceStart;
}
if (actualmbsize[i] == m_mbSize)
{
m_sentenceBegin.SetValue(i, actualmbsize[i]-1, (ElemType)SEQUENCE_END);
m_minibatchPackingFlag[actualmbsize[i]-1] |= MinibatchPackingFlag::SequenceEnd;
}
startFr = m_switchFrame[i];
endFr = m_mbSize;
bool reNewSucc = ReNewBufferForMultiIO(i);
@ -1158,7 +1242,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
}
}
std::map<std::wstring, Matrix<ElemType>*>::iterator iter;
typename std::map<std::wstring, Matrix<ElemType>*>::iterator iter;
for (iter = matrices.begin();iter!=matrices.end(); iter++)
{
// dereference matrix that corresponds to key (input/output name) and
@ -1195,7 +1279,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
{
if (matrices.find(iter->first)==matrices.end())
{
fprintf(stderr,"GetMinibatchToWrite: feature node %ws specified in reader not found in the network\n",iter->first.c_str());
fprintf(stderr,"GetMinibatchToWrite: feature node %ls specified in reader not found in the network\n", iter->first.c_str());
throw std::runtime_error("GetMinibatchToWrite: feature node specified in reader not found in the network.");
}
}
@ -1227,7 +1311,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
{
reader.read (path, featkind, sampperiod, feat); // whole file read as columns of feature vectors
});
fprintf (stderr, "evaluate: reading %d frames of %S\n", feat.cols(), ((wstring)path).c_str());
fprintf (stderr, "evaluate: reading %d frames of %S\n", (int)feat.cols(), ((wstring)path).c_str());
m_fileEvalSource->AddFile(feat, featkind, sampperiod, i);
}
m_inputFileIndex++;
@ -1237,7 +1321,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// populate input matrices
bool first = true;
std::map<std::wstring, Matrix<ElemType>*>::iterator iter;
typename std::map<std::wstring, Matrix<ElemType>*>::iterator iter;
for (iter = matrices.begin();iter!=matrices.end(); iter++)
{
// dereference matrix that corresponds to key (input/output name) and
@ -1256,9 +1340,10 @@ namespace Microsoft { namespace MSR { namespace CNTK {
m_minibatchPackingFlag.resize((size_t)feat.cols());
m_sentenceBegin.SetValue((ElemType)SEQUENCE_MIDDLE);
m_sentenceBegin.SetValue(0, 0, (ElemType)SEQUENCE_START);
m_sentenceBegin.SetValue(0, (size_t)feat.cols() - 1, (ElemType)SEQUENCE_END);
std::fill(m_minibatchPackingFlag.begin(), m_minibatchPackingFlag.end(), MinibatchPackingFlag::None);
m_minibatchPackingFlag[0] = MinibatchPackingFlag::SequenceStart;
m_minibatchPackingFlag[(size_t)feat.cols() - 1] = MinibatchPackingFlag::SequenceEnd;
first = false;
}
@ -1557,7 +1642,17 @@ namespace Microsoft { namespace MSR { namespace CNTK {
}
template<class ElemType>
void HTKMLFReader<ElemType>::SetSentenceSegBatch(Matrix<ElemType> &sentenceBegin, vector<MinibatchPackingFlag>& minibatchPackingFlag)
void HTKMLFReader<ElemType>::SetSentenceEndInBatch(vector<size_t> &sentenceEnd)
{
sentenceEnd.resize(m_switchFrame.size());
for (size_t i = 0; i < m_switchFrame.size() ; i++)
{
sentenceEnd[i] = m_switchFrame[i];
}
}
template<class ElemType>
void HTKMLFReader<ElemType>::SetSentenceSegBatch(Matrix<float> &sentenceBegin, vector<MinibatchPackingFlag>& minibatchPackingFlag)
{
if (!m_framemode)
{
@ -1582,7 +1677,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
{
features.push_back(msra::strfun::utf16(iter->first));
}
else if (temp.ExistsCurrent("mlfFile"))
else if (temp.ExistsCurrent("mlfFile")|| temp.ExistsCurrent("mlfFileList"))
{
labels.push_back(msra::strfun::utf16(iter->first));
}

Просмотреть файл

@ -21,7 +21,9 @@ private:
msra::dbn::minibatchiterator* m_mbiter;
msra::dbn::minibatchsource* m_frameSource;
#ifdef _WIN32
msra::dbn::minibatchreadaheadsource* m_readAheadSource;
#endif
msra::dbn::FileEvalSource* m_fileEvalSource;
msra::dbn::latticesource* m_lattices;
map<wstring,msra::lattices::lattice::htkmlfwordsequence> m_latticeMap;
@ -39,6 +41,8 @@ private:
vector<size_t> m_switchFrame;
bool m_noData;
bool m_trainOrTest; // if false, in file writing mode
using LabelType = typename IDataReader<ElemType>::LabelType;
using LabelIdType = typename IDataReader<ElemType>::LabelIdType;
std::map<LabelIdType, LabelType> m_idToLabelMap;
@ -141,9 +145,29 @@ private:
}
public:
Matrix<ElemType> m_sentenceBegin;
/// a matrix of n_stream x n_length
/// n_stream is the number of streams
/// n_length is the maximum lenght of each stream
/// for example, two sentences used in parallel in one minibatch would be
/// [2 x 5] if the max length of one of the sentences is 5
/// the elements of the matrix is 0, 1, or -1, defined as SEQUENCE_START, SEQUENCE_MIDDLE, NO_INPUT in cbasetype.h
/// 0 1 1 0 1
/// 1 0 1 0 0
/// for two parallel data streams. The first has two sentences, with 0 indicating begining of a sentence
/// the second data stream has two sentences, with 0 indicating begining of sentences
/// you may use 1 even if a sentence begins at that position, in this case, the trainer will carry over hidden states to the following
/// frame.
Matrix<float> m_sentenceBegin;
/// a matrix of 1 x n_length
/// 1 denotes the case that there exists sentnece begin or no_labels case in this frame
/// 0 denotes such case is not in this frame
vector<MinibatchPackingFlag> m_minibatchPackingFlag;
/// by default it is false
/// if true, reader will set to SEQUENCE_MIDDLE for time positions that are orignally correspond to SEQUENCE_START
/// set to true so that a current minibatch can uses state activities from the previous minibatch.
/// default will have truncated BPTT, which only does BPTT inside a minibatch
bool mIgnoreSentenceBeginTag;
HTKMLFReader() : m_sentenceBegin(CPUDEVICE) {
}
@ -158,18 +182,19 @@ public:
virtual bool SupportsDistributedMBRead() const override
{
return m_frameSource->supportsbatchsubsetting();
return ((m_frameSource != nullptr) && m_frameSource->supportsbatchsubsetting());
}
virtual void StartDistributedMinibatchLoop(size_t mbSize, size_t epoch, size_t subsetNum, size_t numSubsets, size_t requestedEpochSamples = requestDataSize) override;
virtual bool GetMinibatch(std::map<std::wstring, Matrix<ElemType>*>& matrices);
virtual const std::map<LabelIdType, LabelType>& GetLabelMapping(const std::wstring& sectionName);
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<unsigned, LabelType>& labelMapping);
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, LabelType>& labelMapping);
virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart=0);
virtual bool DataEnd(EndDataType endDataType);
void SetSentenceSegBatch(Matrix<ElemType> &sentenceBegin, vector<MinibatchPackingFlag>& sentenceExistsBeginOrNoInputs);
void SetSentenceSegBatch(Matrix<float> &sentenceBegin, vector<MinibatchPackingFlag>& sentenceExistsBeginOrNoLabels);
void SetSentenceEndInBatch(vector<size_t> &/*sentenceEnd*/);
void SetSentenceEnd(int /*actualMbSize*/){};
void SetRandomSeed(int){ NOT_IMPLEMENTED };

Просмотреть файл

@ -71,7 +71,7 @@
<Link>
<SubSystem>Windows</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
<AdditionalDependencies>CNTKMath.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
<AdditionalDependencies>CNTKMathDll.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
</Link>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
@ -91,7 +91,7 @@
<GenerateDebugInformation>true</GenerateDebugInformation>
<EnableCOMDATFolding>true</EnableCOMDATFolding>
<OptimizeReferences>true</OptimizeReferences>
<AdditionalDependencies>CNTKMath.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
<AdditionalDependencies>CNTKMathDll.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
<Profile>true</Profile>
</Link>
</ItemDefinitionGroup>

Просмотреть файл

@ -7,7 +7,9 @@
//
#include "stdafx.h"
#ifdef _WIN32
#include <objbase.h>
#endif
#include "basetypes.h"
#include "htkfeatio.h" // for reading HTK features
@ -85,7 +87,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
n++;
}
fprintf (stderr, " %d entries\n", n);
fprintf (stderr, " %d entries\n", (int)n);
if (i==0)
numFiles=n;
@ -163,17 +165,17 @@ namespace Microsoft { namespace MSR { namespace CNTK {
msra::files::make_intermediate_dirs (outputFile);
msra::util::attempt (5, [&]()
{
msra::asr::htkfeatwriter::write (outputFile, "USER", sampPeriod, output);
msra::asr::htkfeatwriter::write (outputFile, "USER", this->sampPeriod, output);
});
fprintf (stderr, "evaluate: writing %d frames of %S\n", output.cols(), outputFile.c_str());
fprintf (stderr, "evaluate: writing %d frames of %S\n", (int)output.cols(), outputFile.c_str());
}
template<class ElemType>
void HTKMLFWriter<ElemType>::SaveMapping(std::wstring saveId, const std::map<typename LabelIdType, typename LabelType>& /*labelMapping*/)
void HTKMLFWriter<ElemType>::SaveMapping(std::wstring saveId, const std::map<LabelIdType, LabelType>& /*labelMapping*/)
{
}

Просмотреть файл

@ -6,6 +6,8 @@
// HTKMLFReader.h - Include file for the MTK and MLF format of features and samples
#pragma once
#include "DataWriter.h"
#include <map>
#include <vector>
namespace Microsoft { namespace MSR { namespace CNTK {
@ -33,11 +35,13 @@ private:
};
public:
using LabelType = typename IDataWriter<ElemType>::LabelType;
using LabelIdType = typename IDataWriter<ElemType>::LabelIdType;
virtual void Init(const ConfigParameters& writerConfig);
virtual void Destroy();
virtual void GetSections(std::map<std::wstring, SectionType, nocase_compare>& sections);
virtual bool SaveData(size_t recordStart, const std::map<std::wstring, void*, nocase_compare>& matrices, size_t numRecords, size_t datasetSize, size_t byteVariableSized);
virtual void SaveMapping(std::wstring saveId, const std::map<typename LabelIdType, typename LabelType>& labelMapping);
virtual void SaveMapping(std::wstring saveId, const std::map<LabelIdType, LabelType>& labelMapping);
};
}}}

Просмотреть файл

@ -82,18 +82,51 @@ OACR_WARNING_DISABLE(POTENTIAL_ARGUMENT_TYPE_MISMATCH, "Not level1 or level2_sec
#pragma warning(disable : 4702) // unreachable code
#endif
#include "Platform.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h> // include here because we redefine some names later
#include <errno.h>
#include <string>
#include <vector>
#include <cmath> // for HUGE_VAL
#include <assert.h>
#include <stdarg.h>
#include <map>
#include <windows.h> // for CRITICAL_SECTION
#include <stdexcept>
#include <locale> // std::wstring_convert
#include <string>
#include <algorithm> // for transform()
#include <mutex>
#include <unordered_map>
#include <chrono>
#include <thread>
#ifdef _MSC_VER
#include <codecvt> // std::codecvt_utf8
#endif
#ifdef _WIN32
#include <windows.h> // for CRITICAL_SECTION and Unicode conversion functions --TODO: is there a portable alternative?
#endif
#if __unix__
#include <strings.h>
#include <unistd.h>
#include <sys/stat.h>
#include <dlfcn.h>
#include <sys/time.h>
typedef unsigned char byte;
#endif
#ifdef _WIN32
#pragma push_macro("STRSAFE_NO_DEPRECATE")
#define STRSAFE_NO_DEPRECATE // deprecation managed elsewhere, not by strsafe
#include <strsafe.h> // for strbcpy() etc templates
#pragma pop_macro("STRSAFE_NO_DEPRECATE")
#endif
using namespace std;
// CRT error handling seems to not be included in wince headers
// so we define our own imports
@ -106,6 +139,7 @@ OACR_WARNING_DISABLE(POTENTIAL_ARGUMENT_TYPE_MISMATCH, "Not level1 or level2_sec
#define strerror(x) "strerror error but can't report error number sorry!"
#endif
#ifdef _WIN32
#ifndef __in // dummies for sal annotations if compiler does not support it
#define __in
#define __inout_z
@ -122,11 +156,103 @@ OACR_WARNING_DISABLE(POTENTIAL_ARGUMENT_TYPE_MISMATCH, "Not level1 or level2_sec
#ifndef __override // and some more non-std extensions required by Office
#define __override virtual
#endif
#endif
// disable warnings for which fixing would make code less readable
#pragma warning(disable : 4290) // throw() declaration ignored
#pragma warning(disable : 4244) // conversion from typeA to typeB, possible loss of data
// ----------------------------------------------------------------------------
// (w)cstring -- helper class like std::string but with auto-cast to char*
// ----------------------------------------------------------------------------
namespace msra { namespace strfun {
// a class that can return a std::string with auto-convert into a const char*
template<typename C> struct basic_cstring : public std::basic_string<C>
{
template<typename S> basic_cstring (S p) : std::basic_string<C> (p) { }
operator const C * () const { return this->c_str(); }
};
typedef basic_cstring<char> cstring;
typedef basic_cstring<wchar_t> wcstring;
}}
static inline wchar_t*GetWC(const char *c)
{
const size_t cSize = strlen(c)+1;
wchar_t* wc = new wchar_t[cSize];
mbstowcs (wc, c, cSize);
return wc;
}
struct MatchPathSeparator
{
bool operator()( char ch ) const
{
return ch == '\\' || ch == '/';
}
};
static inline std::string basename( std::string const& pathname)
{
return std::string (std::find_if(pathname.rbegin(), pathname.rend(),MatchPathSeparator()).base(), pathname.end());
}
static inline std::string removeExtension (std::string const& filename)
{
//std::string::const_reverse_iterator pivot = std::find(filename.rbegin(), filename.rend(), '.');
//return pivot == filename.rend() ? filename: std::string(filename.begin(), pivot.base()-1);
size_t lastindex = filename.find_last_of(".");
return filename.substr(0, lastindex);
}
static inline std::wstring basename( std::wstring const& pathname)
{
return std::wstring (std::find_if(pathname.rbegin(), pathname.rend(),MatchPathSeparator()).base(), pathname.end());
}
static inline std::wstring removeExtension (std::wstring const& filename)
{
//std::wstring::const_reverse_iterator pivot = std::find(filename.rbegin(), filename.rend(), '.');
//return pivot == filename.rend() ? filename: std::wstring(filename.begin(), pivot.base()-1);
size_t lastindex = filename.find_last_of(L".");
return filename.substr(0, lastindex);
}
// ----------------------------------------------------------------------------
// some mappings for non-Windows builds
// ----------------------------------------------------------------------------
#ifndef _MSC_VER // add some functions that are VS-only
// --- basic file functions
// convert a wchar_t path to what gets passed to CRT functions that take narrow characters
// This is needed for the Linux CRT which does not accept wide-char strings for pathnames anywhere.
// Always use this function for mapping the paths.
static inline msra::strfun::cstring charpath (const std::wstring & p)
{
#ifdef _WIN32
return std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>>().to_bytes(p);
#else // old version, delete once we know it works
size_t len = p.length();
std::vector<char> buf(2 * len + 1, 0); // max: 1 wchar => 2 mb chars
::wcstombs(buf.data(), p.c_str(), 2 * len + 1);
return msra::strfun::cstring (&buf[0]);
#endif
}
static inline FILE* _wfopen (const wchar_t * path, const wchar_t * mode) { return fopen(charpath(path), charpath(mode)); }
static inline int _wunlink (const wchar_t * p) { return unlink (charpath (p)); }
static inline int _wmkdir (const wchar_t * p) { return mkdir (charpath (p), 0777/*correct?*/); }
// --- basic string functions
static inline wchar_t* wcstok_s (wchar_t* s, const wchar_t* delim, wchar_t** ptr) { return ::wcstok(s, delim, ptr); }
static inline int _stricmp (const char * a, const char * b) { return ::strcasecmp (a, b); }
static inline int _strnicmp (const char * a, const char * b, size_t n) { return ::strncasecmp (a, b, n); }
static inline int _wcsicmp (const wchar_t * a, const wchar_t * b) { return ::wcscasecmp (a, b); }
static inline int _wcsnicmp (const wchar_t * a, const wchar_t * b, size_t n) { return ::wcsncasecmp (a, b, n); }
static inline int64_t _strtoi64 (const char * s, char ** ep, int r) { return strtoll (s, ep, r); } // TODO: check if correct
static inline uint64_t _strtoui64 (const char * s, char ** ep, int r) { return strtoull (s, ep, r); } // TODO: correct for size_t?
// -- other
//static inline void memcpy_s(void * dst, size_t dstsize, const void * src, size_t maxcount) { assert (maxcount <= dstsize); memcpy (dst, src, maxcount); }
static inline void Sleep (size_t ms) { std::this_thread::sleep_for (std::chrono::milliseconds (ms)); }
#define _countof(_Array) (sizeof(_Array) / sizeof(_Array[0]))
#endif
// ----------------------------------------------------------------------------
// basic macros
// ----------------------------------------------------------------------------
@ -142,6 +268,9 @@ extern void _CHECKED_ASSERT_error(const char * file, int line, const char * exp)
#endif
#endif
#define EPSILON 1e-5
#define ISCLOSE(a, b, threshold) (abs(a - b) < threshold)?true:false
/**
These macros are used for sentence segmentation information.
*/
@ -190,6 +319,8 @@ namespace msra { namespace basetypes {
// class ARRAY -- std::vector with array-bounds checking
// VS 2008 and above do this, so there is no longer a need for this.
#pragma warning(push)
#pragma warning(disable : 4555) // expression has no affect, used so retail won't be empty
template<class _ElemType>
class ARRAY : public std::vector<_ElemType>
@ -296,6 +427,7 @@ public:
};
template<class _T> inline void swap (fixed_vector<_T> & L, fixed_vector<_T> & R) throw() { L.swap (R); }
#pragma warning(pop) // pop off waring: expression has no effect
// class matrix - simple fixed-size 2-dimensional array, access elements as m(i,j)
// stored as concatenation of rows
@ -307,14 +439,14 @@ public:
typedef T elemtype;
matrix() : numcols (0) {}
matrix (size_t n, size_t m) { resize (n, m); }
void resize (size_t n, size_t m) { numcols = m; fixed_vector::resize (n * m); }
void resize (size_t n, size_t m) { numcols = m; fixed_vector<T>::resize (n * m); }
size_t cols() const { return numcols; }
size_t rows() const { return empty() ? 0 : size() / cols(); }
size_t size() const { return fixed_vector::size(); } // use this for reading and writing... not nice!
bool empty() const { return fixed_vector::empty(); }
size_t size() const { return fixed_vector<T>::size(); } // use this for reading and writing... not nice!
bool empty() const { return fixed_vector<T>::empty(); }
T & operator() (size_t i, size_t j) { return (*this)[locate(i,j)]; }
const T & operator() (size_t i, size_t j) const { return (*this)[locate(i,j)]; }
void swap (matrix & other) throw() { std::swap (numcols, other.numcols); fixed_vector::swap (other); }
void swap (matrix & other) throw() { std::swap (numcols, other.numcols); fixed_vector<T>::swap (other); }
};
template<class _T> inline void swap (matrix<_T> & L, matrix<_T> & R) throw() { L.swap (R); }
@ -333,16 +465,16 @@ public:
noncopyable(){}
};
// class CCritSec and CAutoLock -- simple critical section handling
class CCritSec
{
CCritSec (const CCritSec &); CCritSec & operator= (const CCritSec &);
CRITICAL_SECTION m_CritSec;
CCritSec (const CCritSec &) = delete;
CCritSec & operator= (const CCritSec &) = delete;
std::mutex m_CritSec;
public:
CCritSec() { InitializeCriticalSection(&m_CritSec); };
~CCritSec() { DeleteCriticalSection(&m_CritSec); };
void Lock() { EnterCriticalSection(&m_CritSec); };
void Unlock() { LeaveCriticalSection(&m_CritSec); };
CCritSec() {};
~CCritSec() {};
void Lock() { m_CritSec.lock(); };
void Unlock() { m_CritSec.unlock(); };
};
// locks a critical section, and unlocks it automatically
@ -356,6 +488,7 @@ public:
~CAutoLock() { m_rLock.Unlock(); };
};
#ifdef _WIN32
// an efficient way to write COM code
// usage examples:
// COM_function() || throw_hr ("message");
@ -436,9 +569,11 @@ public:
operator void * () { return TlsGetValue (tlsSlot); }
void *operator = (void *val) { if (!TlsSetValue (tlsSlot,val)) throw std::runtime_error ("tls: TlsSetValue failed"); return val; }
};
#endif
};}; // namespace
#ifdef _WIN32
#ifndef BASETYPES_NO_UNSAFECRTOVERLOAD // if on, no unsafe CRT overload functions
// ----------------------------------------------------------------------------
@ -465,7 +600,11 @@ public:
#include <xlocale> // uses strlen()
#endif
#define strlen strlen_
#ifndef LINUX
template<typename _T> inline __declspec(deprecated("Dummy general template, cannot be used directly"))
#else
template<typename _T> inline
#endif // LINUX
size_t strlen_(_T &s) { return strnlen_s(static_cast<const char *>(s), SIZE_MAX); } // never be called but needed to keep compiler happy
template<typename _T> inline size_t strlen_(const _T &s) { return strnlen_s(static_cast<const char *>(s), SIZE_MAX); }
template<> inline size_t strlen_(char * &s) { return strnlen_s(s, SIZE_MAX); }
@ -544,7 +683,10 @@ static inline const char *strerror_(int e)
if (msgs.find(e) == msgs.end()) { char msg[1024]; strerror_s (msg, e); msgs[e] = msg; }
return msgs[e].c_str();
}
#endif
#endif
#ifdef __unix__
extern int fileno(FILE*); // somehow got deprecated in C++11
#endif
// ----------------------------------------------------------------------------
@ -560,8 +702,11 @@ template<class _T> struct _strprintf : public std::basic_string<_T>
{ // works for both wchar_t* and char*
_strprintf (const _T * format, ...)
{
va_list args; va_start (args, format); // varargs stuff
va_list args;
va_start (args, format); // varargs stuff
size_t n = _cprintf (format, args); // num chars excl. '\0'
va_end(args);
va_start(args, format);
const int FIXBUF_SIZE = 128; // incl. '\0'
if (n < FIXBUF_SIZE)
{
@ -576,16 +721,47 @@ template<class _T> struct _strprintf : public std::basic_string<_T>
}
private:
// helpers
inline size_t _cprintf (const wchar_t * format, va_list args) { return _vscwprintf (format, args); }
inline size_t _cprintf (const char * format, va_list args) { return _vscprintf (format, args); }
inline const wchar_t * _sprintf (wchar_t * buf, size_t bufsiz, const wchar_t * format, va_list args) { vswprintf_s (buf, bufsiz, format, args); return buf; }
inline const char * _sprintf ( char * buf, size_t bufsiz, const char * format, va_list args) { vsprintf_s (buf, bufsiz, format, args); return buf; }
inline size_t _cprintf(const wchar_t* format, va_list args)
{
#ifdef __WINDOWS__
return vswprintf(nullptr, 0, format, args);
#elif defined(__UNIX__)
// TODO: Really??? Write to file in order to know the length of a string?
FILE *dummyf = fopen("/dev/null", "w");
if (dummyf == NULL)
perror("The following error occurred in basetypes.h:cprintf");
int n = vfwprintf (dummyf, format, args);
if (n < 0)
perror("The following error occurred in basetypes.h:cprintf");
fclose(dummyf);
return n;
#endif
}
inline size_t _cprintf(const char* format, va_list args)
{
#ifdef __WINDOWS__
return vsprintf(nullptr, format, args);
#elif defined(__UNIX__)
// TODO: Really??? Write to file in order to know the length of a string?
FILE *dummyf = fopen("/dev/null", "wb");
if (dummyf == NULL)
perror("The following error occurred in basetypes.h:cprintf");
int n = vfprintf (dummyf, format, args);
if (n < 0)
perror("The following error occurred in basetypes.h:cprintf");
fclose(dummyf);
return n;
#endif
}
inline const wchar_t * _sprintf(wchar_t * buf, size_t bufsiz, const wchar_t * format, va_list args) { vswprintf(buf, bufsiz, format, args); return buf; }
inline const char * _sprintf ( char * buf, size_t /*bufsiz*/, const char * format, va_list args) { vsprintf (buf, format, args); return buf; }
};
typedef strfun::_strprintf<char> strprintf; // char version
typedef strfun::_strprintf<wchar_t> wstrprintf; // wchar_t version
#endif
#ifdef _WIN32
// string-encoding conversion functions
struct utf8 : std::string { utf8 (const std::wstring & p) // utf-16 to -8
{
@ -612,6 +788,7 @@ struct utf16 : std::wstring { utf16 (const std::string & p) // utf-8 to -16
ASSERT (rc < buf.size ());
(*(std::wstring*)this) = &buf[0];
}};
#endif
#pragma warning(push)
#pragma warning(disable : 4996) // Reviewed by Yusheng Li, March 14, 2006. depr. fn (wcstombs, mbstowcs)
@ -633,6 +810,19 @@ static inline std::wstring mbstowcs (const std::string & p) // input: MBCS
return std::wstring (&buf[0]);
}
#pragma warning(pop)
#ifdef _WIN32
static inline cstring utf8 (const std::wstring & p) { return std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>>().to_bytes(p); } // utf-16 to -8
static inline wcstring utf16 (const std::string & p) { return std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>>().from_bytes(p); } // utf-8 to -16
#else // BUGBUG: we cannot compile the above on Cygwin GCC, so for now fake it using the mbs functions, which will only work for 7-bit ASCII strings
static inline std::string utf8 (const std::wstring & p) { return msra::strfun::wcstombs (p.c_str()); } // output: UTF-8... not really
static inline std::wstring utf16 (const std::string & p) { return msra::strfun::mbstowcs(p.c_str()); } // input: UTF-8... not really
#endif
static inline cstring utf8 (const std::string & p) { return p; } // no conversion (useful in templated functions)
static inline wcstring utf16 (const std::wstring & p) { return p; }
// convert a string to lowercase --TODO: currently only correct for 7-bit ASCII
template<typename CHAR>
static inline void tolower_ascii (std::basic_string<CHAR> & s) { std::transform(s.begin(), s.end(), s.begin(), [] (CHAR c) { return (c >= 0 && c < 128) ? ::tolower(c) : c; }); }
// split and join -- tokenize a string like strtok() would, join() strings together
template<class _T> static inline std::vector<std::basic_string<_T>> split (const std::basic_string<_T> & s, const _T * delim)
@ -662,7 +852,11 @@ template<class _T> static inline std::basic_string<_T> join (const std::vector<s
// parsing strings to numbers
static inline int toint (const wchar_t * s)
{
#ifdef _WIN32
return _wtoi (s); // ... TODO: check it
#else
return (int)wcstol(s, 0, 10);
#endif
}
static inline int toint (const char * s)
{
@ -766,7 +960,7 @@ public:
auto_file_ptr() : f (NULL) { }
~auto_file_ptr() { close(); }
auto_file_ptr (const char * path, const char * mode) { f = fopen (path, mode); if (f == NULL) openfailed (path); }
auto_file_ptr (const wchar_t * path, const char * mode) { f = _wfopen (path, msra::strfun::utf16 (mode).c_str()); if (f == NULL) openfailed (msra::strfun::utf8 (path)); }
auto_file_ptr (const wchar_t * wpath, const char * mode) { f = _wfopen (wpath, msra::strfun::utf16 (mode).c_str()); if (f == NULL) openfailed (msra::strfun::utf8 (wpath)); }
FILE * operator= (FILE * other) { close(); f = other; return f; }
auto_file_ptr (FILE * other) : f (other) { }
operator FILE * () const { return f; }
@ -775,6 +969,7 @@ public:
};
inline int fclose (auto_file_ptr & af) { return af.fclose(); }
#ifdef _MSC_VER
// auto-closing container for Win32 handles.
// Pass close function if not CloseHandle(), e.g.
// auto_handle h (FindFirstFile(...), FindClose);
@ -791,6 +986,7 @@ public:
operator _H () const { return h; }
};
typedef auto_handle_t<HANDLE> auto_handle;
#endif
// like auto_ptr but calls freeFunc_p (type free_func_t) instead of delete to clean up
// minor difference - wrapped object is T, not T *, so to wrap a
@ -814,6 +1010,9 @@ public:
// simple timer
// auto_timer timer; run(); double seconds = timer; // now can abandon the object
#ifdef __unix__
typedef timeval LARGE_INTEGER;
#endif
class auto_timer
{
LARGE_INTEGER freq, start;
@ -821,15 +1020,26 @@ class auto_timer
public:
auto_timer()
{
#ifdef _WIN32
if (!QueryPerformanceFrequency (&freq)) // count ticks per second
throw std::runtime_error ("auto_timer: QueryPerformanceFrequency failure");
QueryPerformanceCounter (&start);
#endif
#ifdef __unix__
gettimeofday (&start, NULL);
#endif
}
operator double() const // each read gives time elapsed since start, in seconds
{
LARGE_INTEGER end;
#ifdef _WIN32
QueryPerformanceCounter (&end);
return (end.QuadPart - start.QuadPart) / (double) freq.QuadPart;
#endif
#ifdef __unix__
gettimeofday (&end,NULL);
return (end.tv_sec - start.tv_sec) + (end.tv_usec - start.tv_usec)/(1000*1000);
#endif
}
void show (const std::string & msg) const
{
@ -881,8 +1091,10 @@ public:
#define foreach_index(_i,_dat) for (int _i = 0; _i < (int) (_dat).size(); _i++)
#define map_array(_x,_expr,_y) { _y.resize (_x.size()); foreach_index(_i,_x) _y[_i]=_expr(_x[_i]); }
#define reduce_array(_x,_expr,_y) { foreach_index(_i,_x) _y = (_i==0) ? _x[_i] : _expr(_y,_x[_i]); }
#ifdef _WIN32
template<class _A,class _F>
static void fill_array(_A & a, _F v) { ::fill (a.begin(), a.end(), v); }
#endif
// ----------------------------------------------------------------------------
// frequently missing utility functions
@ -897,7 +1109,11 @@ namespace msra { namespace util {
class command_line
{
int num;
#ifdef _WIN32
(const wchar_t *) * args;
#else
const wchar_t ** args;
#endif
public:
command_line (int argc, wchar_t * argv[]) : num (argc), args ((const wchar_t **) argv) { shift(); }
inline int size() const { return num; }
@ -948,6 +1164,7 @@ template<typename FUNCTION> static void attempt (int retries, const FUNCTION & b
};}; // namespace
#ifdef _WIN32
// ----------------------------------------------------------------------------
// frequently missing Win32 functions
// ----------------------------------------------------------------------------
@ -988,6 +1205,7 @@ static inline LPWSTR CoTaskMemString (const wchar_t * s)
if (p) for (size_t i = 0; i < n; i++) p[i] = s[i];
return p;
}
#endif
template<class S> static inline void ZeroStruct (S & s) { memset (&s, 0, sizeof (s)); }
@ -1002,9 +1220,7 @@ using namespace msra::basetypes; // for compatibility
#pragma warning (pop)
// RuntimeError - throw a std::runtime_error with a formatted error string
#ifdef _MSC_VER
__declspec(noreturn)
#endif
__declspec_noreturn
static inline void RuntimeError(const char * format, ...)
{
va_list args;
@ -1016,9 +1232,7 @@ static inline void RuntimeError(const char * format, ...)
};
// LogicError - throw a std::logic_error with a formatted error string
#ifdef _MSC_VER
__declspec(noreturn)
#endif
__declspec_noreturn
static inline void LogicError(const char * format, ...)
{
va_list args;
@ -1047,7 +1261,7 @@ public:
m_dllName += L".dll";
m_hModule = LoadLibrary(m_dllName.c_str());
if (m_hModule == NULL)
RuntimeError("Plugin not found: %s", msra::strfun::utf8(m_dllName));
RuntimeError("Plugin not found: %s", msra::strfun::utf8(m_dllName).c_str());
// create a variable of each type just to call the proper templated version
return GetProcAddress(m_hModule, proc.c_str());
@ -1057,14 +1271,37 @@ public:
#else
class Plugin
{
private:
void *handle;
public:
Plugin()
{
handle = NULL;
}
template<class STRING> // accepts char (UTF-8) and wide string
void * Load(const STRING & plugin, const std::string & proc)
{
RuntimeError("Plugins not implemented on Linux yet");
return nullptr;
string soName = msra::strfun::utf8(plugin);
soName = soName + ".so";
void *handle = dlopen(soName.c_str(), RTLD_LAZY);
if (handle == NULL)
RuntimeError("Plugin not found: %s (error: %s)", soName.c_str(), dlerror());
return dlsym(handle, proc.c_str());
}
~Plugin()
{
if (handle != NULL)
dlclose(handle);
}
};
#endif
template<class F>
static inline bool comparator(const pair<int, F>& l, const pair<int, F>& r)
{
return l.second > r.second;
}
#endif // _BASETYPES_

Просмотреть файл

@ -92,30 +92,30 @@ public:
// ---------------------------------------------------------------------------
// biggrowablevector -- big vector we can push_back to
// ---------------------------------------------------------------------------
template<typename ELEMTYPE> class biggrowablevector : public growablevectorbase<std::vector<ELEMTYPE>>
template<class ELEMTYPE> class biggrowablevector : public growablevectorbase<std::vector<ELEMTYPE>>
{
public:
biggrowablevector() : growablevectorbase (65536) { }
biggrowablevector() : growablevectorbase<std::vector<ELEMTYPE>>::growablevectorbase (65536) { }
template<typename VALTYPE> void push_back (VALTYPE e) // VALTYPE could be an rvalue reference
{
size_t i = size();
resize_without_commit (i + 1);
auto & block = getblockptr (i);
size_t i = this->size();
this->resize_without_commit (i + 1);
auto & block = this->getblockptr (i);
if (block.get() == NULL)
block.reset (new std::vector<ELEMTYPE> (elementsperblock));
(*block)[getblockt (i)] = e;
block.reset (new std::vector<ELEMTYPE> (this->elementsperblock));
(*block)[this->getblockt (i)] = e;
}
ELEMTYPE & operator[] (size_t t) { return getblock(t)[getblockt (t)]; } // get an element
const ELEMTYPE & operator[] (size_t t) const { return getblock(t)[getblockt (t)]; } // get an element
ELEMTYPE & operator[] (size_t t) { return this->getblock(t)[this->getblockt (t)]; } // get an element
const ELEMTYPE & operator[] (size_t t) const { return this->getblock(t)[this->getblockt (t)]; } // get an element
void resize (const size_t n)
{
resize_without_commit (n);
foreach_index (i, blocks)
if (blocks[i].get() == NULL)
blocks[i].reset (new std::vector<ELEMTYPE> (elementsperblock));
this->resize_without_commit (n);
foreach_index (i, this->blocks)
if (this->blocks[i].get() == NULL)
this->blocks[i].reset (new std::vector<ELEMTYPE> (this->elementsperblock));
}
};

Просмотреть файл

@ -10,7 +10,9 @@
#include "basetypes.h" // for attempt()
#include "htkfeatio.h" // for reading HTK features
#include "minibatchsourcehelpers.h"
#include "ssematrix.h"
#ifndef __unix__
#include "ssematrix.h" // TODO: why can it not be removed for Windows as well? At least needs a comment here.
#endif
#ifdef LEAKDETECT
#include <vld.h> // for memory leak detection
@ -58,7 +60,7 @@ namespace msra { namespace dbn {
unsigned int sampperiod = sampperiods[k];
size_t n = numframes[k];
msra::files::make_intermediate_dirs (outfile);
fprintf (stderr, "saveandflush: writing %d frames to %S\n", n, outfile.c_str());
fprintf (stderr, "saveandflush: writing %d frames to %S\n", (int)n, outfile.c_str());
msra::dbn::matrixstripe thispred (pred, firstframe, n);
// some sanity check for the data we've written
const size_t nansinf = thispred.countnaninf();
@ -171,7 +173,7 @@ namespace msra { namespace dbn {
unsigned int sampperiod = sampperiods[index][k];
size_t n = numframes[k];
msra::files::make_intermediate_dirs (outfile);
fprintf (stderr, "saveandflush: writing %d frames to %S\n", n, outfile.c_str());
fprintf (stderr, "saveandflush: writing %d frames to %S\n", (int)n, outfile.c_str());
msra::dbn::matrixstripe thispred (pred, firstframe, n);
// some sanity check for the data we've written
const size_t nansinf = thispred.countnaninf();

Просмотреть файл

@ -245,7 +245,7 @@ void fflushOrDie (FILE * f)
// ----------------------------------------------------------------------------
size_t filesize (FILE * f)
{
#ifdef WIN32
#ifdef _WIN32
size_t curPos = _ftelli64 (f);
if (curPos == -1L)
{
@ -269,6 +269,27 @@ size_t filesize (FILE * f)
return len;
#else
// linux version
long curPos = ftell (f);
if (curPos == -1L)
{
RuntimeError ("error determining file position: %s", strerror (errno));
}
int rc = fseek (f, 0, SEEK_END);
if (rc != 0)
{
RuntimeError ("error seeking to end of file: %s", strerror (errno));
}
long len = ftell (f);
if (len == -1L)
{
RuntimeError ("error determining file position: %s", strerror (errno));
}
rc = fseek (f, curPos, SEEK_SET);
if (rc != 0)
{
RuntimeError ("error resetting file position: %s", strerror (errno));
}
return (size_t) len;
#endif
}

Просмотреть файл

@ -10,13 +10,28 @@
#ifndef _FILEUTIL_
#define _FILEUTIL_
#include "Platform.h"
#ifdef _WIN32
#include "basetypes.h"
#endif
#include <stdio.h>
#ifdef __WINDOWS__
#include <windows.h> // for mmreg.h and FILETIME
#include <mmreg.h>
#endif
#ifdef __unix__
#include <sys/types.h>
#include <sys/stat.h>
#endif
#include <algorithm> // for std::find
#include <vector>
#include <map>
#include <functional>
#include <cctype>
#include <errno.h>
#include <stdint.h>
#include <assert.h>
#include <string.h> // for strerror()
using namespace std;
#define SAFE_CLOSE(f) (((f) == NULL) || (fcloseOrDie ((f)), (f) = NULL))
@ -28,8 +43,8 @@ using namespace std;
// not to fclose() such a handle.
// ----------------------------------------------------------------------------
FILE * fopenOrDie (const STRING & pathname, const char * mode);
FILE * fopenOrDie (const WSTRING & pathname, const wchar_t * mode);
FILE * fopenOrDie (const string & pathname, const char * mode);
FILE * fopenOrDie (const wstring & pathname, const wchar_t * mode);
#ifndef __unix__ // don't need binary/text distinction on unix
// ----------------------------------------------------------------------------
@ -44,7 +59,9 @@ void fsetmode (FILE * f, char type);
// ----------------------------------------------------------------------------
void freadOrDie (void * ptr, size_t size, size_t count, FILE * f);
#ifdef _WIN32
void freadOrDie (void * ptr, size_t size, size_t count, const HANDLE f);
#endif
template<class _T>
void freadOrDie (_T & data, int num, FILE * f) // template for vector<>
@ -53,12 +70,14 @@ template<class _T>
void freadOrDie (_T & data, size_t num, FILE * f) // template for vector<>
{ data.resize (num); if (data.size() > 0) freadOrDie (&data[0], sizeof (data[0]), data.size(), f); }
#ifdef _WIN32
template<class _T>
void freadOrDie (_T & data, int num, const HANDLE f) // template for vector<>
{ data.resize (num); if (data.size() > 0) freadOrDie (&data[0], sizeof (data[0]), data.size(), f); }
template<class _T>
void freadOrDie (_T & data, size_t num, const HANDLE f) // template for vector<>
{ data.resize (num); if (data.size() > 0) freadOrDie (&data[0], sizeof (data[0]), data.size(), f); }
#endif
// ----------------------------------------------------------------------------
@ -66,15 +85,19 @@ void freadOrDie (_T & data, size_t num, const HANDLE f) // template for vecto
// ----------------------------------------------------------------------------
void fwriteOrDie (const void * ptr, size_t size, size_t count, FILE * f);
#ifdef _WIN32
void fwriteOrDie (const void * ptr, size_t size, size_t count, const HANDLE f);
#endif
template<class _T>
void fwriteOrDie (const _T & data, FILE * f) // template for vector<>
{ if (data.size() > 0) fwriteOrDie (&data[0], sizeof (data[0]), data.size(), f); }
#ifdef _WIN32
template<class _T>
void fwriteOrDie (const _T & data, const HANDLE f) // template for vector<>
{ if (data.size() > 0) fwriteOrDie (&data[0], sizeof (data[0]), data.size(), f); }
#endif
// ----------------------------------------------------------------------------
@ -111,6 +134,10 @@ int64_t filesize64 (const wchar_t * pathname);
// 32-bit offsets only
long fseekOrDie (FILE * f, long offset, int mode = SEEK_SET);
#define ftellOrDie ftell
// ----------------------------------------------------------------------------
// fget/setpos(): seek functions with error handling
// ----------------------------------------------------------------------------
uint64_t fgetpos (FILE * f);
void fsetpos (FILE * f, uint64_t pos);
@ -158,27 +185,6 @@ void fskipspace (FILE * F);
// fskipNewLine(): skip all white space until end of line incl. the newline
// ----------------------------------------------------------------------------
template<class CHAR> CHAR * fgetline (FILE * f, CHAR * buf, int size);
template<class CHAR, size_t n> CHAR * fgetline (FILE * f, CHAR (& buf)[n]) { return fgetline (f, buf, n); }
STRING fgetline (FILE * f);
WSTRING fgetlinew (FILE * f);
void fgetline (FILE * f, std::string & s, ARRAY<char> & buf);
void fgetline (FILE * f, std::wstring & s, ARRAY<char> & buf);
void fgetline (FILE * f, ARRAY<char> & buf);
void fgetline (FILE * f, ARRAY<wchar_t> & buf);
const char * fgetstring (FILE * f, char * buf, int size);
template<size_t n> const char * fgetstring (FILE * f, char (& buf)[n]) { return fgetstring (f, buf, n); }
const char * fgetstring (const HANDLE f, char * buf, int size);
template<size_t n> const char * fgetstring (const HANDLE f, char (& buf)[n]) { return fgetstring (f, buf, n); }
wstring fgetwstring (FILE * f);
const char * fgettoken (FILE * f, char * buf, int size);
template<size_t n> const char * fgettoken (FILE * f, char (& buf)[n]) { return fgettoken (f, buf, n); }
STRING fgettoken (FILE * f);
void fskipNewline (FILE * f);
// ----------------------------------------------------------------------------
// fputstring(): write a 0-terminated string (terminate if error)
// ----------------------------------------------------------------------------
@ -189,32 +195,75 @@ void fputstring (FILE * f, const std::string &);
void fputstring (FILE * f, const wchar_t *);
void fputstring (FILE * f, const std::wstring &);
template<class CHAR> CHAR * fgetline (FILE * f, CHAR * buf, int size);
template<class CHAR, size_t n> CHAR * fgetline (FILE * f, CHAR (& buf)[n]) { return fgetline (f, buf, n); }
string fgetline (FILE * f);
wstring fgetlinew (FILE * f);
void fgetline (FILE * f, std::string & s, std::vector<char> & buf);
void fgetline (FILE * f, std::wstring & s, std::vector<char> & buf);
void fgetline (FILE * f, std::vector<char> & buf);
void fgetline (FILE * f, std::vector<wchar_t> & buf);
const char * fgetstring (FILE * f, char * buf, int size);
template<size_t n> const char * fgetstring (FILE * f, char (& buf)[n]) { return fgetstring (f, buf, n); }
const char * fgetstring (const HANDLE f, char * buf, int size);
template<size_t n> const char * fgetstring (const HANDLE f, char (& buf)[n]) { return fgetstring (f, buf, n); }
const wchar_t * fgetstring (FILE * f, wchar_t * buf, int size);
wstring fgetwstring (FILE * f);
string fgetstring (FILE * f);
const char * fgettoken (FILE * f, char * buf, int size);
template<size_t n> const char * fgettoken (FILE * f, char (& buf)[n]) { return fgettoken (f, buf, n); }
string fgettoken (FILE * f);
const wchar_t * fgettoken (FILE * f, wchar_t * buf, int size);
wstring fgetwtoken (FILE * f);
int fskipNewline (FILE * f, bool skip = true);
int fskipwNewline (FILE * f, bool skip = true);
// ----------------------------------------------------------------------------
// fputstring(): write a 0-terminated string (terminate if error)
// ----------------------------------------------------------------------------
void fputstring (FILE * f, const char *);
#ifdef _WIN32
void fputstring (const HANDLE f, const char * str);
#endif
void fputstring (FILE * f, const std::string &);
void fputstring (FILE * f, const wchar_t *);
void fputstring (FILE * f, const std::wstring &);
// ----------------------------------------------------------------------------
// fgetTag(): read a 4-byte tag & return as a string
// ----------------------------------------------------------------------------
STRING fgetTag (FILE * f);
string fgetTag (FILE * f);
// ----------------------------------------------------------------------------
// fcheckTag(): read a 4-byte tag & verify it; terminate if wrong tag
// ----------------------------------------------------------------------------
void fcheckTag (FILE * f, const char * expectedTag);
#ifdef _WIN32
void fcheckTag (const HANDLE f, const char * expectedTag);
void fcheckTag_ascii (FILE * f, const STRING & expectedTag);
#endif
void fcheckTag_ascii (FILE * f, const string & expectedTag);
// ----------------------------------------------------------------------------
// fcompareTag(): compare two tags; terminate if wrong tag
// ----------------------------------------------------------------------------
void fcompareTag (const STRING & readTag, const STRING & expectedTag);
void fcompareTag (const string & readTag, const string & expectedTag);
// ----------------------------------------------------------------------------
// fputTag(): write a 4-byte tag
// ----------------------------------------------------------------------------
void fputTag (FILE * f, const char * tag);
#ifdef _WIN32
void fputTag(const HANDLE f, const char * tag);
#endif
// ----------------------------------------------------------------------------
// fskipstring(): skip a 0-terminated string, such as a pad string
@ -252,10 +301,17 @@ int fgetint24 (FILE * f);
// ----------------------------------------------------------------------------
int fgetint (FILE * f);
#ifdef _WIN32
int fgetint (const HANDLE f);
#endif
int fgetint_bigendian (FILE * f);
int fgetint_ascii (FILE * f);
// ----------------------------------------------------------------------------
// fgetlong(): read an long value
// ----------------------------------------------------------------------------
long fgetlong (FILE * f);
// ----------------------------------------------------------------------------
// fgetfloat(): read a float value
// ----------------------------------------------------------------------------
@ -270,6 +326,7 @@ float fgetfloat_ascii (FILE * f);
double fgetdouble (FILE * f);
#ifdef _WIN32
// ----------------------------------------------------------------------------
// fgetwav(): read an entire .wav file
// ----------------------------------------------------------------------------
@ -283,6 +340,7 @@ void fgetwav (const wstring & fn, ARRAY<short> & wav, int & sampleRate);
void fputwav (FILE * f, const vector<short> & wav, int sampleRate, int nChannels = 1);
void fputwav (const wstring & fn, const vector<short> & wav, int sampleRate, int nChannels = 1);
#endif
// ----------------------------------------------------------------------------
// fputbyte(): write a byte value
@ -307,7 +365,16 @@ void fputint24 (FILE * f, int v);
// ----------------------------------------------------------------------------
void fputint (FILE * f, int val);
// ----------------------------------------------------------------------------
// fputlong(): write an long value
// ----------------------------------------------------------------------------
void fputlong (FILE * f, long val);
#ifdef _WIN32
void fputint (const HANDLE f, int v);
#endif
// ----------------------------------------------------------------------------
// fputfloat(): write a float value
@ -320,27 +387,154 @@ void fputfloat (FILE * f, float val);
// ----------------------------------------------------------------------------
void fputdouble (FILE * f, double val);
// template versions of put/get functions for binary files
template <typename T>
void fput(FILE * f, T v)
{
fwriteOrDie (&v, sizeof (v), 1, f);
}
// template versions of put/get functions for binary files
template <typename T>
void fget(FILE * f, T& v)
{
freadOrDie ((void *)&v, sizeof (v), 1, f);
}
// GetFormatString - get the format string for a particular type
template <typename T>
const wchar_t* GetFormatString(T /*t*/)
{
// if this _ASSERT goes off it means that you are using a type that doesn't have
// a read and/or write routine.
// If the type is a user defined class, you need to create some global functions that handles file in/out.
// for example:
//File& operator>>(File& stream, MyClass& test);
//File& operator<<(File& stream, MyClass& test);
//
// in your class you will probably want to add these functions as friends so you can access any private members
// friend File& operator>>(File& stream, MyClass& test);
// friend File& operator<<(File& stream, MyClass& test);
//
// if you are using wchar_t* or char* types, these use other methods because they require buffers to be passed
// either use std::string and std::wstring, or use the WriteString() and ReadString() methods
assert(false); // need a specialization
return NULL;
}
// GetFormatString - specalizations to get the format string for a particular type
template <> const wchar_t* GetFormatString(char);
template <> const wchar_t* GetFormatString(wchar_t);
template <> const wchar_t* GetFormatString(short);
template <> const wchar_t* GetFormatString(int);
template <> const wchar_t* GetFormatString(long);
template <> const wchar_t* GetFormatString(unsigned short);
template <> const wchar_t* GetFormatString(unsigned int);
template <> const wchar_t* GetFormatString(unsigned long);
template <> const wchar_t* GetFormatString(float);
template <> const wchar_t* GetFormatString(double);
template <> const wchar_t* GetFormatString(size_t);
template <> const wchar_t* GetFormatString(long long);
template <> const wchar_t* GetFormatString(const char*);
template <> const wchar_t* GetFormatString(const wchar_t*);
// GetScanFormatString - get the format string for a particular type
template <typename T>
const wchar_t* GetScanFormatString(T t)
{
assert(false); // need a specialization
return NULL;
}
// GetScanFormatString - specalizations to get the format string for a particular type
template <> const wchar_t* GetScanFormatString(char);
template <> const wchar_t* GetScanFormatString(wchar_t);
template <> const wchar_t* GetScanFormatString(short);
template <> const wchar_t* GetScanFormatString(int);
template <> const wchar_t* GetScanFormatString(long);
template <> const wchar_t* GetScanFormatString(unsigned short);
template <> const wchar_t* GetScanFormatString(unsigned int);
template <> const wchar_t* GetScanFormatString(unsigned long);
template <> const wchar_t* GetScanFormatString(float);
template <> const wchar_t* GetScanFormatString(double);
template <> const wchar_t* GetScanFormatString(size_t);
template <> const wchar_t* GetScanFormatString(long long);
// ----------------------------------------------------------------------------
// fgetText(): get a value from a text file
// ----------------------------------------------------------------------------
template <typename T>
void fgetText(FILE * f, T& v)
{
int rc = ftrygetText(f, v);
if (rc == 0)
throw std::runtime_error("error reading value from file (invalid format)");
else if (rc == EOF)
throw std::runtime_error(std::string("error reading from file: ") + strerror(errno));
assert(rc == 1);
}
// version to try and get a string, and not throw exceptions if contents don't match
template <typename T>
int ftrygetText(FILE * f, T& v)
{
const wchar_t* formatString = GetScanFormatString<T>(v);
int rc = fwscanf (f, formatString, &v);
assert(rc == 1 || rc == 0);
return rc;
}
template <> int ftrygetText<bool>(FILE * f, bool& v);
// ----------------------------------------------------------------------------
// fgetText() specializations for fwscanf_s differences: get a value from a text file
// ----------------------------------------------------------------------------
void fgetText(FILE * f, char& v);
void fgetText(FILE * f, wchar_t& v);
// ----------------------------------------------------------------------------
// fputText(): write a value out as text
// ----------------------------------------------------------------------------
template <typename T>
void fputText(FILE * f, T v)
{
const wchar_t* formatString = GetFormatString(v);
int rc = fwprintf(f, formatString, v);
if (rc == 0)
throw std::runtime_error("error writing value to file, no values written");
else if (rc < 0)
throw std::runtime_error(std::string("error writing to file: ") + strerror(errno));
}
// ----------------------------------------------------------------------------
// fputText(): write a bool out as character
// ----------------------------------------------------------------------------
template <> void fputText<bool>(FILE * f, bool v);
// ----------------------------------------------------------------------------
// fputfile(): write a binary block or a string as a file
// ----------------------------------------------------------------------------
void fputfile (const WSTRING & pathname, const ARRAY<char> & buffer);
void fputfile (const WSTRING & pathname, const std::wstring & string);
void fputfile (const WSTRING & pathname, const std::string & string);
void fputfile (const wstring & pathname, const std::vector<char> & buffer);
void fputfile (const wstring & pathname, const std::wstring & string);
void fputfile (const wstring & pathname, const std::string & string);
// ----------------------------------------------------------------------------
// fgetfile(): load a file as a binary block
// ----------------------------------------------------------------------------
void fgetfile (const WSTRING & pathname, ARRAY<char> & buffer);
void fgetfile (FILE * f, ARRAY<char> & buffer);
void fgetfile (const wstring & pathname, std::vector<char> & buffer);
void fgetfile (FILE * f, std::vector<char> & buffer);
namespace msra { namespace files {
void fgetfilelines (const std::wstring & pathname, vector<char> & readbuffer, std::vector<std::string> & lines);
static inline std::vector<std::string> fgetfilelines (const std::wstring & pathname) { vector<char> buffer; std::vector<std::string> lines; fgetfilelines (pathname, buffer, lines); return lines; }
vector<char*> fgetfilelines (const wstring & pathname, vector<char> & readbuffer);
};};
#ifdef _WIN32
// ----------------------------------------------------------------------------
// getfiletime(), setfiletime(): access modification time
// ----------------------------------------------------------------------------
@ -348,6 +542,7 @@ namespace msra { namespace files {
bool getfiletime (const std::wstring & path, FILETIME & time);
void setfiletime (const std::wstring & path, const FILETIME & time);
#endif
// ----------------------------------------------------------------------------
// expand_wildcards() -- expand a path with wildcards (also intermediate ones)
// ----------------------------------------------------------------------------
@ -370,6 +565,7 @@ namespace msra { namespace files {
bool fuptodate (const wstring & target, const wstring & input, bool inputrequired = true);
};};
#ifdef _WIN32
// ----------------------------------------------------------------------------
// simple support for WAV file I/O
// ----------------------------------------------------------------------------
@ -408,7 +604,8 @@ void fputwfx (FILE *f, const WAVEFORMATEX & wfx, unsigned int numSamples);
// For example, data[i][j]: i is channel index, 0 means the first
// channel. j is sample index.
// ----------------------------------------------------------------------------
void fgetraw (FILE *f,ARRAY< ARRAY<short> > & data,const WAVEHEADER & wavhd);
void fgetraw (FILE *f,std::vector< std::vector<short> > & data,const WAVEHEADER & wavhd);
#endif
// ----------------------------------------------------------------------------
// temp functions -- clean these up
@ -445,4 +642,23 @@ static inline bool relpath (const wchar_t * path)
template<class CHAR>
static inline bool relpath (const std::basic_string<CHAR> & s) { return relpath (s.c_str()); }
// trim from start
static inline std::string &ltrim(std::string &s) {
s.erase(s.begin(), std::find_if(s.begin(), s.end(), std::not1(std::ptr_fun<int, int>(std::isspace))));
return s;
}
// trim from end
static inline std::string &rtrim(std::string &s) {
s.erase(std::find_if(s.rbegin(), s.rend(), std::not1(std::ptr_fun<int, int>(std::isspace))).base(), s.end());
return s;
}
// trim from both ends
static inline std::string &trim(std::string &s) {
return ltrim(rtrim(s));
}
vector<string> sep_string(const string & str, const string & sep);
#endif // _FILEUTIL_

Просмотреть файл

@ -14,8 +14,10 @@
#include <string>
#include <regex>
#include <set>
#include <hash_map>
#include <unordered_map>
#include <stdint.h>
#include <limits.h>
#include <wchar.h>
namespace msra { namespace asr {
@ -263,9 +265,11 @@ public:
#else
W.close (numframes);
#endif
#ifdef _WIN32 // BUGBUG: and on Linux??
// rename to final destination
// (This would only fail in strange circumstances such as accidental multiple processes writing to the same file.)
renameOrDie (tmppath, path);
#endif
}
};
@ -386,7 +390,7 @@ private:
{
wstring physpath = ppath.physicallocation();
//auto_file_ptr f = fopenOrDie (physpath, L"rbS");
auto_file_ptr f = fopenOrDie (physpath, L"rb"); // removed 'S' for now, as we mostly run local anyway, and this will speed up debugging
auto_file_ptr f(fopenOrDie (physpath, L"rb")); // removed 'S' for now, as we mostly run local anyway, and this will speed up debugging
// read the header (12 bytes for htk feature files)
fileheader H;
@ -655,7 +659,7 @@ private:
public:
// parse format with original HTK state align MLF format and state list
void parsewithstatelist (const vector<char*> & toks, const hash_map<const string, size_t> & statelisthash, const double htkTimeToFrame)
void parsewithstatelist (const vector<char*> & toks, const unordered_map<std::string, size_t> & statelisthash, const double htkTimeToFrame)
{
size_t ts, te;
parseframerange (toks, ts, te, htkTimeToFrame);
@ -682,7 +686,7 @@ template<class ENTRY, class WORDSEQUENCE>
class htkmlfreader : public map<wstring,vector<ENTRY>> // [key][i] the data
{
wstring curpath; // for error messages
hash_map<const std::string, size_t> statelistmap; // for state <=> index
unordered_map<std::string, size_t> statelistmap; // for state <=> index
map<wstring,WORDSEQUENCE> wordsequences; // [key] word sequences (if we are building word entries as well, for MMI)
void strtok (char * s, const char * delim, vector<char*> & toks)
@ -700,7 +704,7 @@ class htkmlfreader : public map<wstring,vector<ENTRY>> // [key][i] the data
vector<char*> readlines (const wstring & path, vector<char> & buffer)
{
// load it into RAM in one huge chunk
auto_file_ptr f = fopenOrDie (path, L"rb");
auto_file_ptr f(fopenOrDie (path, L"rb"));
size_t len = filesize (f);
buffer.reserve (len +1);
freadOrDie (buffer, len, f);
@ -752,7 +756,12 @@ class htkmlfreader : public map<wstring,vector<ENTRY>> // [key][i] the data
filename = filename.substr (1, filename.length() -2); // strip quotes
if (filename.find ("*/") == 0) filename = filename.substr (2);
#ifdef _WIN32
wstring key = msra::strfun::utf16 (regex_replace (filename, regex ("\\.[^\\.\\\\/:]*$"), string())); // delete extension (or not if none)
#endif
#ifdef __unix__
wstring key = msra::strfun::utf16 (removeExtension(basename(filename))); // note that c++ 4.8 is incomplete for supporting regex
#endif
// determine lines range
size_t s = line;
@ -785,7 +794,7 @@ class htkmlfreader : public map<wstring,vector<ENTRY>> // [key][i] the data
const char * w = toks[6]; // the word name
int wid = (*wordmap)[w]; // map to word id --may be -1 for unseen words in the transcript (word list typically comes from a test LM)
size_t wordindex = (wid == -1) ? WORDSEQUENCE::word::unknownwordindex : (size_t) wid;
wordseqbuffer.push_back (WORDSEQUENCE::word (wordindex, entries[i-s].firstframe, alignseqbuffer.size()));
wordseqbuffer.push_back (typename WORDSEQUENCE::word (wordindex, entries[i-s].firstframe, alignseqbuffer.size()));
}
if (unitmap)
{
@ -796,7 +805,7 @@ class htkmlfreader : public map<wstring,vector<ENTRY>> // [key][i] the data
if (iter == unitmap->end())
throw std::runtime_error (string ("parseentry: unknown unit ") + u + " in utterance " + strfun::utf8 (key));
const size_t uid = iter->second;
alignseqbuffer.push_back (WORDSEQUENCE::aligninfo (uid, 0/*#frames--we accumulate*/));
alignseqbuffer.push_back (typename WORDSEQUENCE::aligninfo (uid, 0/*#frames--we accumulate*/));
}
if (alignseqbuffer.empty())
throw std::runtime_error ("parseentry: lonely senone entry at start without phone/word entry found, for utterance " + strfun::utf8 (key));
@ -880,7 +889,7 @@ public:
template<typename WORDSYMBOLTABLE, typename UNITSYMBOLTABLE>
void read (const wstring & path, const set<wstring> & restricttokeys, const WORDSYMBOLTABLE * wordmap, const UNITSYMBOLTABLE * unitmap, const double htkTimeToFrame)
{
if (!restricttokeys.empty() && size() >= restricttokeys.size()) // no need to even read the file if we are there (we support multiple files)
if (!restricttokeys.empty() && this->size() >= restricttokeys.size()) // no need to even read the file if we are there (we support multiple files)
return;
fprintf (stderr, "htkmlfreader: reading MLF file %S ...", path.c_str());
@ -888,18 +897,18 @@ public:
vector<char> buffer; // buffer owns the characters--don't release until done
vector<char*> lines = readlines (path, buffer);
vector<WORDSEQUENCE::word> wordsequencebuffer;
vector<WORDSEQUENCE::aligninfo> alignsequencebuffer;
vector<typename WORDSEQUENCE::word> wordsequencebuffer;
vector<typename WORDSEQUENCE::aligninfo> alignsequencebuffer;
if (lines.empty() || strcmp (lines[0], "#!MLF!#")) malformed ("header missing");
// parse entries
size_t line = 1;
while (line < lines.size() && (restricttokeys.empty() || size() < restricttokeys.size()))
while (line < lines.size() && (restricttokeys.empty() || this->size() < restricttokeys.size()))
parseentry (lines, line, restricttokeys, wordmap, unitmap, wordsequencebuffer, alignsequencebuffer, htkTimeToFrame);
curpath.clear();
fprintf (stderr, " total %lu entries\n", size());
fprintf (stderr, " total %lu entries\n", this->size());
}
// read state list, index is from 0

Просмотреть файл

@ -18,7 +18,7 @@
#include <vector>
#include <string>
#include <set>
#include <hash_map>
#include <unordered_map>
#include <regex>
#pragma warning(disable : 4996)
@ -95,20 +95,6 @@ static size_t tryfind (const MAPTYPE & map, const KEYTYPE & key, VALTYPE deflt)
const msra::asr::htkmlfreader<msra::asr::htkmlfentry,msra::lattices::lattice::htkmlfwordsequence> & labels, // non-empty: build numer lattices
const msra::lm::CMGramLM & unigram, const msra::lm::CSymbolSet & unigramsymbols) // for numer lattices
{
#if 0 // little unit test helper for testing the read function
bool test = true;
if (test)
{
archive a;
a.open (outpath + L".toc");
lattice L;
std::hash_map<string,size_t> symmap;
a.getlattice (L"sw2001_A_1263622500_1374610000", L, symmap);
a.getlattice (L"sw2001_A_1391162500_1409287500", L, symmap);
return;
}
#endif
const bool numermode = !labels.empty(); // if labels are passed then we shall convert the MLFs to lattices, and 'infiles' are regular keys
const std::wstring tocpath = outpath + L".toc";

Просмотреть файл

@ -12,6 +12,8 @@
#undef HACK_IN_SILENCE // [v-hansu] hack to simulate DEL in the lattice
#define SILENCE_PENALTY // give penalty to added silence
#define __STDC_FORMAT_MACROS
#include <inttypes.h>
#include "basetypes.h"
#include "latticestorage.h"
@ -20,11 +22,9 @@
#include <stdint.h>
#include <vector>
#include <string>
#include <hash_map>
#include <unordered_map>
#include <algorithm> // for find()
#include "simplesenonehmm.h"
namespace msra { namespace math { class ssematrixbase; template<class ssematrixbase> class ssematrix; template<class ssematrixbase> class ssematrixstriperef; };};
namespace msra { namespace lm { class CMGramLM; class CSymbolSet; };}; // for numer-lattice building
@ -60,7 +60,7 @@ class lattice
size_t impliedspunitid : 31; // id of implied last unit (intended as /sp/); only used in V2
size_t hasacscores : 1; // if 1 then ac scores are embedded
header_v1_v2() : numnodes (0), numedges (0), lmf (1.0f), wp (0.0f), frameduration (0.01/*assumption*/), numframes (0), impliedspunitid (SIZE_MAX), hasacscores (1) { }
header_v1_v2() : numnodes (0), numedges (0), lmf (1.0f), wp (0.0f), frameduration (0.01/*assumption*/), numframes (0), impliedspunitid (INT_MAX), hasacscores (1) { }
};
header_v1_v2 info; // information about the lattice
static const unsigned int NOEDGE = 0xffffff; // 24 bits
@ -188,7 +188,7 @@ public: // TODO: make private again once
if (ai.size() < 2) // less than 2--must be /sil/
continue;
spunit = ai[ai.size() - 1].unit;
fprintf (stderr, "builduniquealignments: /sp/ unit inferred through heuristics as %d\n", spunit);
fprintf (stderr, "builduniquealignments: /sp/ unit inferred through heuristics as %d\n", (int)spunit);
break;
}
}
@ -235,7 +235,7 @@ public: // TODO: make private again once
&& nodes[edges[prevj].E].t == nodes[edges[j].E].t
&& edges[prevj].l != edges[j].l) // some diagnostics
fprintf (stderr, "build: merging edges %d and %d despite slightly different LM scores %.8f vs. %.8f, ts/te=%.2f/%.2f\n",
prevj, j, edges[prevj].l, edges[j].l, nodes[edges[prevj].S].t * 0.01f, nodes[edges[prevj].E].t * 0.01f);
(int)prevj, (int)j, edges[prevj].l, edges[j].l, nodes[edges[prevj].S].t * 0.01f, nodes[edges[prevj].E].t * 0.01f);
#endif
if (prevj == SIZE_MAX || fabs (edges[prevj].l - edges[j].l) > lmargin || (info.hasacscores && edges[prevj].a != edges[j].a) || comparealign (prevj, j, false) != 0)
{
@ -287,8 +287,8 @@ public: // TODO: make private again once
const size_t uniquealigntokens = uniquededgedatatokens.size() - (numuniquealignments * (info.hasacscores ? 2 : 1));
const size_t nonuniquenonsptokens = align.size() - numimpliedsp;
fprintf (stderr, "builduniquealignments: %d edges: %d unique alignments (%.2f%%); %d align tokens - %d implied /sp/ units = %d, uniqued to %d (%.2f%%)\n",
edges.size(), numuniquealignments, 100.0f * numuniquealignments / edges.size(),
align.size(), numimpliedsp, nonuniquenonsptokens, uniquealigntokens, 100.0f * uniquealigntokens / nonuniquenonsptokens);
(int)edges.size(), (int)numuniquealignments, 100.0f * numuniquealignments / edges.size(),
(int)align.size(), (int)numimpliedsp, (int)nonuniquenonsptokens, (int)uniquealigntokens, 100.0f * uniquealigntokens / nonuniquenonsptokens);
// sort it back into original order (sorted by E, then by S)
sort (edges2.begin(), edges2.end(), [&] (const edgeinfo & e1, const edgeinfo & e2) { return latticeorder (e1, e2) < 0; });
@ -507,7 +507,7 @@ public:
}
};
typedef aligninfo aligninfo; // now we can access it as htkmlfwordsequence::aligninfo although it comes from some totally other corner of the system
typedef msra::lattices::aligninfo aligninfo; // now we can access it as htkmlfwordsequence::aligninfo although it comes from some totally other corner of the system
std::vector<word> words;
std::vector<aligninfo> align;
@ -593,7 +593,7 @@ private:
#if 1 // multiple /sil/ -> log this (as we are not sure whether this is actually proper--probably it is)
if (numsilunits > 1)
{
fprintf (stderr, "backpointers: lattice '%S', edge %d has %d /sil/ phonemes\n", L.getkey(), j, numsilunits);
fprintf (stderr, "backpointers: lattice '%S', edge %d has %d /sil/ phonemes\n", L.getkey(), j, (int)numsilunits);
fprintf (stderr, "alignments: :");
foreach_index (a, aligntokens)
{
@ -643,9 +643,9 @@ private:
double bestpathlattice (const std::vector<float> & edgeacscores, std::vector<double> & logpps,
const float lmf, const float wp, const float amf) const;
static float lattice::alignedge (const_array_ref<aligninfo> units, const msra::asr::simplesenonehmm & hset,
const msra::math::ssematrixbase & logLLs, msra::math::ssematrixbase & gammas,
size_t edgeindex, const bool returnsenoneids, array_ref<unsigned short> thisedgealignments);
static float alignedge (const_array_ref<aligninfo> units, const msra::asr::simplesenonehmm & hset,
const msra::math::ssematrixbase & logLLs, msra::math::ssematrixbase & gammas,
size_t edgeindex, const bool returnsenoneids, array_ref<unsigned short> thisedgealignments);
const_array_ref<aligninfo> getaligninfo (size_t j) const { size_t begin = (size_t) edges[j].firstalign; size_t end = j+1 < edges.size() ? (size_t) edges[j+1].firstalign : align.size(); return const_array_ref<aligninfo> (align.data() + begin, end - begin); }
@ -674,9 +674,9 @@ private:
const std::vector<float> & transcriptunigrams, const msra::math::ssematrixbase & logLLs,
const msra::asr::simplesenonehmm & hset, const float lmf, const float wp, const float amf);
static float lattice::forwardbackwardedge (const_array_ref<aligninfo> units, const msra::asr::simplesenonehmm & hset,
const msra::math::ssematrixbase & logLLs, msra::math::ssematrixbase & gammas,
size_t edgeindex);
static float forwardbackwardedge (const_array_ref<aligninfo> units, const msra::asr::simplesenonehmm & hset,
const msra::math::ssematrixbase & logLLs, msra::math::ssematrixbase & gammas,
size_t edgeindex);
double forwardbackwardlattice (const std::vector<float> & edgeacscores, parallelstate & parallelstate,
std::vector<double> & logpps, std::vector<double> & logalphas, std::vector<double> & logbetas,
@ -747,7 +747,7 @@ public:
for (size_t j = 0; j < info.numedges; j++)
totaledgeframes += nodes[edges[j].E].t - (size_t) nodes[edges[j].S].t;
fprintf (stderr, "lattice: read %d nodes, %d edges, %d units, %d frames, %.1f edges/node, %.1f units/edge, %.1f frames/edge, density %.1f\n",
info.numnodes, info.numedges, align.size(), info.numframes,
(int)info.numnodes, (int)info.numedges, (int)align.size(), (int)info.numframes,
info.numedges / (double) info.numnodes, align.size() / (double) info.numedges, totaledgeframes / (double) info.numedges, totaledgeframes / (double) info.numframes);
}
@ -895,7 +895,7 @@ public:
#if 1 // post-bugfix for incorrect inference of spunit
if (info.impliedspunitid != SIZE_MAX && info.impliedspunitid >= idmap.size()) // we have buggy lattices like that--what do they mean??
{
fprintf (stderr, "fread: detected buggy spunit id %d which is out of range (%d entries in map)\n", info.impliedspunitid, idmap.size());
fprintf (stderr, "fread: detected buggy spunit id %d which is out of range (%d entries in map)\n", (int)info.impliedspunitid, (int)idmap.size());
throw std::runtime_error ("fread: out of bounds spunitid");
}
#endif
@ -949,7 +949,7 @@ public:
k += skipscoretokens;
uniquealignments++;
}
fprintf (stderr, "fread: mapped %d unique alignments\n", uniquealignments);
fprintf (stderr, "fread: mapped %d unique alignments\n", (int)uniquealignments);
}
if (info.impliedspunitid != spunit)
{
@ -1078,7 +1078,7 @@ class archive
mutable size_t currentarchiveindex; // which archive is open
mutable auto_file_ptr f; // cached archive file handle of currentarchiveindex
hash_map<std::wstring,latticeref> toc; // [key] -> (file, offset) --table of content (.toc file)
unordered_map<std::wstring, latticeref> toc; // [key] -> (file, offset) --table of content (.toc file)
public:
// construct = open the archive
//archive() : currentarchiveindex (SIZE_MAX) {}
@ -1091,13 +1091,13 @@ public:
{
if (tocpaths.empty()) // nothing to read--keep silent
return;
fprintf (stderr, "archive: opening %d lattice-archive TOC files ('%S' etc.)..", tocpaths.size(), tocpaths[0].c_str());
fprintf (stderr, "archive: opening %d lattice-archive TOC files ('%S' etc.)..", (int)tocpaths.size(), tocpaths[0].c_str());
foreach_index (i, tocpaths)
{
fprintf (stderr, ".");
open (tocpaths[i]);
}
fprintf (stderr, " %d total lattices referenced in %d archive files\n", toc.size(), archivepaths.size());
fprintf (stderr, " %d total lattices referenced in %d archive files\n", (int)toc.size(), (int)archivepaths.size());
}
// open an archive
@ -1133,7 +1133,12 @@ public:
throw std::runtime_error ("open: invalid TOC line (empty archive pathname): " + std::string (line));
char c;
uint64_t offset;
#ifdef _WIN32
if (sscanf_s (q, "[%I64u]%c", &offset, &c, sizeof (c)) != 1)
#else
if (sscanf (q, "[%" PRIu64 "]%c", &offset, &c) != 1)
#endif
throw std::runtime_error ("open: invalid TOC line (bad [] expression): " + std::string (line));
if (!toc.insert (make_pair (key, latticeref (offset, archiveindex))).second)
throw std::runtime_error ("open: TOC entry leads to duplicate key: " + std::string (line));

Просмотреть файл

@ -25,7 +25,7 @@ static void checkoverflow (size_t fieldval, size_t targetval, const char * field
if (fieldval != targetval)
{
char buf[1000];
sprintf_s (buf, "lattice: bit field %s too small for value 0x%x (cut from 0x%x)", fieldname, targetval, fieldval);
sprintf_s (buf, sizeof(buf), "lattice: bit field %s too small for value 0x%x (cut from 0x%x)", fieldname, (unsigned int)targetval, (unsigned int)fieldval);
throw std::runtime_error (buf);
}
}

Просмотреть файл

@ -209,7 +209,7 @@ public:
{
firstvalidepochstartframe = source.firstvalidglobalts (epochstartframe); // epochstartframe may fall between utterance boundaries; this gets us the first valid boundary
fprintf (stderr, "minibatchiterator: epoch %d: frames [%d..%d] (first utterance at frame %d), data subset %d of %d, with %d datapasses\n",
epoch, epochstartframe, epochendframe, firstvalidepochstartframe, subsetnum, numsubsets, datapasses);
(int)epoch, (int)epochstartframe, (int)epochendframe, (int)firstvalidepochstartframe, (int)subsetnum, (int)numsubsets, (int)datapasses);
mbstartframe = firstvalidepochstartframe;
datapass = 0;
fillorclear(); // get the first batch
@ -228,7 +228,7 @@ public:
{
firstvalidepochstartframe = source.firstvalidglobalts (epochstartframe); // epochstartframe may fall between utterance boundaries; this gets us the first valid boundary
fprintf (stderr, "minibatchiterator: epoch %d: frames [%d..%d] (first utterance at frame %d), data subset %d of %d, with %d datapasses\n",
epoch, epochstartframe, epochendframe, firstvalidepochstartframe, subsetnum, numsubsets, datapasses);
(int)epoch, (int)epochstartframe, (int)epochendframe, (int)firstvalidepochstartframe, (int)subsetnum, (int)numsubsets, (int)datapasses);
mbstartframe = firstvalidepochstartframe;
datapass = 0;
fillorclear(); // get the first batch
@ -253,7 +253,7 @@ public:
{
mbstartframe = firstvalidepochstartframe;
datapass++;
fprintf (stderr, "\nminibatchiterator: entering %d-th repeat pass through the data\n", datapass+1);
fprintf (stderr, "\nminibatchiterator: entering %d-th repeat pass through the data\n", (int)(datapass+1));
}
fillorclear();
}

Просмотреть файл

@ -12,7 +12,9 @@
#include <stdio.h>
#include <vector>
#include <algorithm>
#ifndef __unix__
#include "ssematrix.h" // for matrix type
#endif
namespace msra { namespace dbn {
@ -246,7 +248,7 @@ public:
retries++;
}
}
fprintf (stderr, "randomordering: %d retries for %d elements (%.1f%%) to ensure window condition\n", retries, map.size(), 100.0 * retries / map.size());
fprintf (stderr, "randomordering: %d retries for %d elements (%.1f%%) to ensure window condition\n", (int)retries, (int)map.size(), 100.0 * retries / map.size());
// ensure the window condition
foreach_index (t, map) assert ((size_t) t <= map[t] + randomizationrange/2 && map[t] < (size_t) t + randomizationrange/2);
#if 1 // and a live check since I don't trust myself here yet

Просмотреть файл

@ -12,7 +12,7 @@
#include "fileutil.h" // for opening/reading the ARPA file
#include <vector>
#include <string>
#include <hash_map>
#include <unordered_map>
#include <algorithm> // for various sort() calls
#include <math.h>
@ -22,8 +22,9 @@ namespace msra { namespace lm {
// core LM interface -- LM scores are accessed through this exclusively
// ===========================================================================
interface ILM // generic interface -- mostly the score() function
class ILM // generic interface -- mostly the score() function
{
public:
virtual double score (const int * mgram, int m) const = 0;
virtual bool oov (int w) const = 0; // needed for perplexity calculation
// ... TODO (?): return true/false to indicate whether anything changed.
@ -31,8 +32,9 @@ interface ILM // generic interface -- mostly the score() function
virtual void adapt (const int * data, size_t m) = 0; // (NULL,M) to reset, (!NULL,0) to flush
// iterator for composing models --iterates in increasing order w.r.t. w
interface IIter
class IIter
{
public:
virtual operator bool() const = 0; // has iterator not yet reached end?
// ... TODO: ensure iterators do not return OOVs w.r.t. user symbol table
// (It needs to be checked which LM type's iterator currently does.)
@ -85,7 +87,7 @@ static inline double invertlogprob (double logP) { return logclip (1.0 - exp (lo
// CSymbolSet -- a simple symbol table
// ===========================================================================
// compare function to allow char* as keys (without, hash_map will correctly
// compare function to allow char* as keys (without, unordered_map will correctly
// compute a hash key from the actual strings, but then compare the pointers
// -- duh!)
struct less_strcmp : public binary_function<const char *, const char *, bool>
@ -94,7 +96,7 @@ struct less_strcmp : public binary_function<const char *, const char *, bool>
{ return strcmp (_Left, _Right) < 0; }
};
class CSymbolSet : public stdext::hash_map<const char *, int, stdext::hash_compare<const char*,less_strcmp>>
class CSymbolSet : public std::unordered_map<const char *, int, std::hash<const char*>, less_strcmp>
{
vector<const char *> symbols; // the symbols
@ -106,14 +108,14 @@ public:
void clear()
{
foreach_index (i, symbols) free ((void*) symbols[i]);
hash_map::clear();
unordered_map::clear();
}
// operator[key] on a 'const' object
// get id for an existing word, returns -1 if not existing
int operator[] (const char * key) const
{
hash_map<const char *,int>::const_iterator iter = find (key);
unordered_map<const char *, int>::const_iterator iter = find(key);
return (iter != end()) ? iter->second : -1;
}
@ -121,14 +123,18 @@ public:
// determine unique id for a word ('key')
int operator[] (const char * key)
{
hash_map<const char *,int>::const_iterator iter = find (key);
unordered_map<const char *, int>::const_iterator iter = find(key);
if (iter != end())
return iter->second;
// create
const char * p = _strdup (key);
if (!p)
#ifdef _WIN32
throw std::bad_exception ("CSymbolSet:id string allocation failure");
#else
throw std::bad_exception ();
#endif
try
{
int id = (int) symbols.size();
@ -274,7 +280,7 @@ class mgram_map
{
typedef unsigned int index_t; // (-> size_t when we really need it)
//typedef size_t index_t; // (tested once, seems to work)
static const index_t nindex = (index_t) -1; // invalid index
static const index_t nindex; // invalid index
// entry [m][i] is first index of children in level m+1, entry[m][i+1] the end.
int M; // order, e.g. M=3 for trigram
std::vector<std::vector<index_t>> firsts; // [M][i] ([0] = zerogram = root)
@ -1124,7 +1130,7 @@ public:
void read (const std::wstring & pathname, SYMMAP & userSymMap, bool filterVocabulary, int maxM)
{
int lineNo = 0;
msra::basetypes::auto_file_ptr f = fopenOrDie (pathname, L"rbS");
msra::basetypes::auto_file_ptr f(fopenOrDie (pathname, L"rbS"));
fprintf (stderr, "read: reading %S", pathname.c_str());
filename = pathname; // (keep this info for debugging)
@ -1769,7 +1775,7 @@ protected:
mcounts.push_back (mmap.create (newkey, mmapCache), count); // store 'count' under 'key'
}
}
fprintf (stderr, " %d %d-grams", mcounts.size (m), m);
fprintf (stderr, " %d %d-grams", (int)mcounts.size (m), m);
}
// remove used up tokens from the buffer
@ -1977,10 +1983,11 @@ public:
//// set prune value to 0 3 3
//setMinObs (iMinObs);
for (size_t i = 0; i < minObs.size(); i++)
{
MESSAGE("minObs %d: %d.", i, minObs[i]);
}
// TODO: Re-enable when MESSAGE definition is provided (printf?)
// for (size_t i = 0; i < minObs.size(); i++)
// {
// MESSAGE("minObs %d: %d.", i, minObs[i]);
// }
estimate (startId, minObs, dropWord);
@ -2027,7 +2034,7 @@ public:
while (M > 0 && counts.size (M) == 0) resize (M-1);
for (int m = 1; m <= M; m++)
fprintf (stderr, "estimate: read %d %d-grams\n", counts.size (m), m);
fprintf (stderr, "estimate: read %d %d-grams\n", (int)counts.size (m), m);
// === Kneser-Ney smoothing
// This is a strange algorithm.
@ -2197,8 +2204,8 @@ public:
for (int m = 1; m <= M; m++)
{
fprintf (stderr, "estimate: %d-grams after pruning: %d out of %d (%.1f%%)\n", m,
numMGrams[m], counts.size (m),
100.0 * numMGrams[m] / max (counts.size (m), 1));
numMGrams[m], (int)counts.size (m),
100.0 * numMGrams[m] / max (counts.size (m), size_t(1)));
}
// ensure M reflects the actual order of read data after pruning
@ -2282,6 +2289,9 @@ public:
}
}
double dcount;
double dP;
// pruned case
if (count == 0) // this entry was pruned before
goto skippruned;
@ -2314,7 +2324,7 @@ public:
}
// estimate discounted probability
double dcount = count; // "modified Kneser-Ney" discounting
dcount = count; // "modified Kneser-Ney" discounting
if (count >= 3) dcount -= d3[m];
else if (count == 2) dcount -= d2[m];
else if (count == 1) dcount -= d1[m];
@ -2323,7 +2333,7 @@ public:
if (histCount == 0)
RuntimeError ("estimate: unexpected 0 denominator");
double dP = dcount / histCount;
dP = dcount / histCount;
// and this is the discounted probability value
{
// Actually, 'key' uses a "mapped" word ids, while create()
@ -2412,7 +2422,7 @@ skippruned:; // m-gram was pruned
updateOOVScore();
fprintf (stderr, "estimate: done");
for (int m = 1; m <= M; m++) fprintf (stderr, ", %d %d-grams", logP.size (m), m);
for (int m = 1; m <= M; m++) fprintf (stderr, ", %d %d-grams", (int)logP.size (m), m);
fprintf (stderr, "\n");
}
};
@ -2521,7 +2531,7 @@ skipMGram:
wstring dir, file;
splitpath (clonepath, dir, file); // we allow relative paths in the file
msra::basetypes::auto_file_ptr f = fopenOrDie (clonepath, L"rbS");
msra::basetypes::auto_file_ptr f(fopenOrDie (clonepath, L"rbS"));
std::string line = fgetline (f);
if (line != "#clone")
throw runtime_error ("read: invalid header line " + line);

Просмотреть файл

@ -7,9 +7,11 @@
#pragma once
#ifndef __unix__
#include <Windows.h>
#include <stdexcept>
#include "pplhelpers.h"
#endif
#include <stdexcept>
#include "simple_checked_arrays.h"
#include "basetypes.h" // for FormatWin32Error

Просмотреть файл

@ -8,8 +8,9 @@
#pragma once
#ifndef __unix__
#include <ppl.h>
#endif
namespace msra { namespace parallel {
// ===========================================================================

Просмотреть файл

@ -12,7 +12,9 @@
#include "basetypes.h"
#include "minibatchiterator.h"
#include "latticearchive.h"
#ifdef _WIN32
#include "simplethread.h"
#endif
#include <deque>
#include <stdexcept>

Просмотреть файл

@ -9,7 +9,9 @@
#pragma once
#include "basetypes.h" // for attempt()
#ifdef _WIN32
#include "numahelpers.h" // for NUMA allocation
#endif
#include "minibatchsourcehelpers.h"
#include "minibatchiterator.h"
#include "biggrowablevectors.h"
@ -37,9 +39,13 @@ namespace msra { namespace dbn {
msra::dbn::matrix * newblock() const
{
// we stripe the data across NUMA nodes as to not fill up one node with the feature data
#ifdef _WIN32
msra::numa::overridenode ((int) msra::numa::getmostspaciousnumanode());
#endif
msra::dbn::matrix * res = new msra::dbn::matrix (m, elementsperblock);
#ifdef _WIN32
msra::numa::overridenode (-1); // note: we really should reset it also in case of failure
#endif
return res;
}
@ -100,7 +106,7 @@ namespace msra { namespace dbn {
size_t blockid = t0 / elementsperblock;
assert (blockid * elementsperblock == t0);
assert (blocks[blockid]);
fprintf (stderr, "recoverblock: releasing feature block %d [%d..%d)\n", blockid, t0, t0 + elementsperblock -1);
fprintf (stderr, "recoverblock: releasing feature block %d [%d..%d)\n", (int)blockid, (int)t0, (int)(t0 + elementsperblock -1));
blocks[blockid].reset(); // free the memory
}
void recoverblock (size_t t0) // t0=block start time
@ -109,7 +115,7 @@ namespace msra { namespace dbn {
size_t blockid = t0 / elementsperblock;
assert (blockid * elementsperblock == t0);
assert (!blocks[blockid]);
fprintf (stderr, "recoverblock: recovering feature block %d [%d..%d)\n", blockid, t0, t0 + elementsperblock -1);
fprintf (stderr, "recoverblock: recovering feature block %d [%d..%d)\n", (int)blockid, (int)t0, (int)(t0 + elementsperblock -1));
blocks[blockid].reset (newblock());
msra::dbn::matrix & block = *blocks[blockid];
fsetpos (f, blockid * block.sizeinpagefile());
@ -163,7 +169,7 @@ namespace msra { namespace dbn {
// finish off last block
flushlastblock();
fflushOrDie (f);
fprintf (stderr, "biggrowablevectorarray: disk backup store created, %d frames, %ull bytes\n", (int) n, fgetpos (f));
fprintf (stderr, "biggrowablevectorarray: disk backup store created, %d frames, %lu bytes\n", (int) n, fgetpos (f));
fclose (f);
foreach_index (i, blocks) assert (!blocks[i]); // ensure we flushed
assert (inmembegin == inmemend); // nothing in cache
@ -265,7 +271,7 @@ namespace msra { namespace dbn {
// - implement block-wise paging directly from HTK feature files through htkfeatreader
featkind.clear();
std::vector<float> frame;
fprintf (stderr, "minibatchframesource: reading %d utterances..", infiles.size());
fprintf (stderr, "minibatchframesource: reading %d utterances..", (int)infiles.size());
size_t numclasses = 0; // number of units found (actually max id +1)
size_t notfound = 0; // number of entries missing in MLF
msra::asr::htkfeatreader reader; // feature reader
@ -281,7 +287,12 @@ namespace msra { namespace dbn {
wstring key;
if (!labels.empty()) // empty means unsupervised mode (don't load any)
{
#ifdef _WIN32
key = regex_replace ((wstring)ppath, wregex (L"\\.[^\\.\\\\/:]*$"), wstring()); // delete extension (or not if none)
#endif
#ifdef __unix__
key = removeExtension(basename(ppath));
#endif
if (labels.find (key) == labels.end())
{
if (notfound < 5)
@ -309,7 +320,7 @@ namespace msra { namespace dbn {
size_t labframes = labseq.empty() ? 0 : (labseq[labseq.size()-1].firstframe + labseq[labseq.size()-1].numframes);
if (abs ((int) labframes - (int) feat.cols()) > 0)
{
fprintf (stderr, "\nminibatchframesource: %d-th file has small duration mismatch (%d in label vs. %d in feat file), skipping: %S", i, labframes, feat.cols(), key.c_str());
fprintf (stderr, "\nminibatchframesource: %d-th file has small duration mismatch (%d in label vs. %d in feat file), skipping: %S", i, (int)labframes, (int)feat.cols(), key.c_str());
notfound++;
continue; // skip this utterance at all
}
@ -346,7 +357,7 @@ namespace msra { namespace dbn {
if (e.classid != (CLASSIDTYPE) e.classid)
throw std::runtime_error ("CLASSIDTYPE has too few bits");
classids.push_back ((CLASSIDTYPE) e.classid);
numclasses = max (numclasses, 1u + e.classid);
numclasses = max (numclasses, (size_t)(1u + e.classid));
}
}
if (vdim == 0)
@ -364,10 +375,10 @@ namespace msra { namespace dbn {
assert (labels.empty() || numframes == classids.size());
if ((vdim != 0 && numframes != frames.size()) || (!labels.empty() && numframes != classids.size()))
throw std::runtime_error ("minibatchframesource: numframes variable screwup");
fprintf (stderr, " %d frames read from %d utterances; %d classes\n", numframes, infiles.size(), numclasses);
fprintf (stderr, " %d frames read from %d utterances; %d classes\n", (int)numframes, (int)infiles.size(), (int)numclasses);
if (notfound > 0)
{
fprintf (stderr, "minibatchframesource: %d files out of %d not found in label set\n", notfound, infiles.size());
fprintf (stderr, "minibatchframesource: %d files out of %d not found in label set\n", (int)notfound, (int)infiles.size());
if (notfound > infiles.size() / 2)
throw std::runtime_error ("minibatchframesource: too many files not found in label set--assuming broken configuration\n");
}
@ -421,7 +432,7 @@ namespace msra { namespace dbn {
const size_t te = min (ts + framesrequested, totalframes()); // do not go beyond sweep boundary
assert (te > ts);
if (verbosity >= 2)
fprintf (stderr, "getbatch: frames [%d..%d] in sweep %d\n", ts, te-1, sweep);
fprintf (stderr, "getbatch: frames [%d..%d] in sweep %d\n", (int)ts, (int)(te-1), (int)sweep);
// get random sequence (each time index occurs exactly once)
// If the sweep changes, this will re-cache the sequence. We optimize for rare, monotonous sweep changes.
@ -543,7 +554,7 @@ namespace msra { namespace dbn {
}
fprintf (stderr, "minibatchframesourcemulti: reading %d feature sets and %d label sets...", infiles.size(),labels.size());
fprintf (stderr, "minibatchframesourcemulti: reading %d feature sets and %d label sets...", (int)infiles.size(), (int)labels.size());
foreach_index (m, infiles)
{
@ -567,7 +578,12 @@ namespace msra { namespace dbn {
{
if (!labels[0].empty()) // empty means unsupervised mode (don't load any)
{
#ifdef _WIN32
key = regex_replace ((wstring)ppath, wregex (L"\\.[^\\.\\\\/:]*$"), wstring()); // delete extension (or not if none)
#endif
#ifdef __unix__
key = removeExtension(basename(ppath));
#endif
if (labels[0].find (key) == labels[0].end())
{
if (notfound < 5)
@ -595,7 +611,7 @@ namespace msra { namespace dbn {
size_t labframes = labseq.empty() ? 0 : (labseq[labseq.size()-1].firstframe + labseq[labseq.size()-1].numframes);
if (abs ((int) labframes - (int) feat.cols()) > 0)
{
fprintf (stderr, "\nminibatchframesourcemulti: %d-th file has small duration mismatch (%d in label vs. %d in feat file), skipping: %S", i, labframes, feat.cols(), key.c_str());
fprintf (stderr, "\nminibatchframesourcemulti: %d-th file has small duration mismatch (%d in label vs. %d in feat file), skipping: %S", i, (int)labframes, (int)feat.cols(), key.c_str());
notfound++;
continue; // skip this utterance at all
}
@ -645,7 +661,7 @@ namespace msra { namespace dbn {
if (e.classid != (CLASSIDTYPE) e.classid)
throw std::runtime_error ("CLASSIDTYPE has too few bits");
classids[j].push_back ((CLASSIDTYPE) e.classid);
numclasses[j] = max (numclasses[j], 1u + e.classid);
numclasses[j] = max (numclasses[j], (size_t)(1u + e.classid));
}
}
if (vdim[m] == 0)
@ -676,12 +692,12 @@ namespace msra { namespace dbn {
if (m==0)
{
foreach_index (j, numclasses)
fprintf (stderr, "\nminibatchframesourcemulti: read label set %d: %d classes\n", j, numclasses[j]);
fprintf (stderr, "\nminibatchframesourcemulti: read label set %d: %d classes\n", j, (int)numclasses[j]);
}
fprintf (stderr, "\nminibatchframesourcemulti: feature set %d: %d frames read from %d utterances\n", m, pframes[m]->size(), infiles[m].size());
fprintf (stderr, "\nminibatchframesourcemulti: feature set %d: %d frames read from %d utterances\n", m, (int)pframes[m]->size(), (int)infiles[m].size());
if (notfound > 0)
{
fprintf (stderr, "minibatchframesourcemulti: %d files out of %d not found in label set\n", notfound, infiles[m].size());
fprintf (stderr, "minibatchframesourcemulti: %d files out of %d not found in label set\n", (int)notfound, (int)infiles[m].size());
if (notfound > infiles[m].size() / 2)
throw std::runtime_error ("minibatchframesourcemulti: too many files not found in label set--assuming broken configuration\n");
}
@ -741,7 +757,7 @@ namespace msra { namespace dbn {
const size_t te = min (ts + framesrequested, totalframes()); // do not go beyond sweep boundary
assert (te > ts);
if (verbosity >= 2)
fprintf (stderr, "getbatch: frames [%d..%d] in sweep %d\n", ts, te-1, sweep);
fprintf (stderr, "getbatch: frames [%d..%d] in sweep %d\n", (int)ts, (int)(te-1), (int)sweep);
// get random sequence (each time index occurs exactly once)
// If the sweep changes, this will re-cache the sequence. We optimize for rare, monotonous sweep changes.

Просмотреть файл

@ -64,7 +64,7 @@ public: // (TODO: better encapsulation)
transP() : numstates (0) {}
};
std::vector<transP> transPs; // the transition matrices --TODO: finish this
std::hash_map<std::string,size_t> transPmap; // [transPname] -> index into transPs[]
std::unordered_map<std::string, size_t> transPmap; // [transPname] -> index into transPs[]
public:
// get an hmm by index
const hmm & gethmm (size_t i) const { return hmms[i]; }
@ -216,7 +216,7 @@ public:
}
}
fprintf (stderr, "simplesenonehmm: %d units with %d unique HMMs, %d tied states, and %d trans matrices read\n",
symmap.size(), hmms.size(), statemap.size(), transPs.size());
(int)symmap.size(), (int)hmms.size(), (int)statemap.size(), (int)transPs.size());
}
// exposed so we can pass it to the lattice reader, which maps the symbol ids for us

Просмотреть файл

@ -9,7 +9,9 @@
#pragma once
#include "basetypes.h"
#ifdef _WIN32
#include <process.h> // for _beginthread()
#endif
namespace msra { namespace util {

Просмотреть файл

@ -8,7 +8,12 @@
#pragma once
#ifdef _WIN32
#include <intrin.h> // for intrinsics
#endif
#ifdef __unix__
#include <x86intrin.h>
#endif
namespace msra { namespace math {

Просмотреть файл

@ -13,11 +13,14 @@
#include "simple_checked_arrays.h" // ... for dotprod(); we can eliminate this I believe
#include "ssefloat4.h"
#include <stdexcept>
#include "numahelpers.h"
#ifndef __unix__
#include <ppl.h>
#include "pplhelpers.h"
#include "numahelpers.h"
#endif
#include "fileutil.h" // for saving and reading matrices
#include <limits> // for NaN
#include <malloc.h>
namespace msra { namespace math {
@ -275,7 +278,7 @@ public:
bool addtoresult, const float thisscale, const float weight)
{
assert (a.size() == b.size());
assert ((15 & (int) &a[0]) == 0); assert ((15 & (int) &b[0]) == 0); // enforce SSE alignment
assert ((15 & reinterpret_cast<uintptr_t>(&a[0])) == 0); assert ((15 & reinterpret_cast<uintptr_t>(&b[0])) == 0); // enforce SSE alignment
size_t nlong = (a.size() + 3) / 4; // number of SSE elements
const msra::math::float4 * pa = (const msra::math::float4 *) &a[0];
@ -310,9 +313,9 @@ public:
// for (size_t k = 0; k < 4; k++)
// dotprod (row, const_array_ref<float> (&cols4[k * cols4stride], cols4stride), usij[k * usijstride]);
assert ((15 & (int) &row[0]) == 0);
assert ((15 & (int) &cols4[0]) == 0);
assert ((15 & (int) &cols4[cols4stride]) == 0);
assert ((15 & reinterpret_cast<uintptr_t>(&row[0])) == 0);
assert ((15 & reinterpret_cast<uintptr_t>(&cols4[0])) == 0);
assert ((15 & reinterpret_cast<uintptr_t>(&cols4[cols4stride])) == 0);
//assert (cols4stride * 4 == cols4.size()); // (passed in one vector with 4 columns stacked on top of each other)
//assert (row.size() * 4 == cols4.size()); // this assert is no longer appropriate because of further breaking into blocks
@ -389,6 +392,7 @@ public:
matprod_mtm (Mt, 0, Mt.cols(), V);
}
#ifdef _WIN32
void parallel_matprod_mtm (const ssematrixbase & Mt, const ssematrixbase & V)
{
msra::parallel::foreach_index_block (Mt.cols(), Mt.cols(), 1, [&] (size_t i0, size_t i1)
@ -396,6 +400,7 @@ public:
matprod_mtm (Mt, i0, i1, V);
});
}
#endif
// swap data of i-th column and j-th column
void swapcolumn (size_t i, size_t j)
@ -801,6 +806,7 @@ public:
scaleandaddmatprod_mtm (thisscale, Mt, 0, Mt.cols(), V);
}
#ifdef _WIN32
void parallel_scaleandaddmatprod_mtm (const float thisscale, const ssematrixbase & Mt, const ssematrixbase & V)
{
#if 0
@ -813,6 +819,7 @@ public:
});
#endif
}
#endif
// same as matprod_mtm except result is added to result matrix instead of replacing it
// For all comments, see matprod_mtm.
@ -912,6 +919,7 @@ public:
// to = this'
void transpose (ssematrixbase & to) const { transposecolumns (to, 0, cols()); }
#ifdef _WIN32
void parallel_transpose (ssematrixbase & to) const
{
msra::parallel::foreach_index_block (cols(), cols(), 4/*align*/, [&] (size_t j0, size_t j1)
@ -925,6 +933,7 @@ public:
throw std::logic_error ("parallel_transpose: post-condition check failed--you got it wrong, man!");
#endif
}
#endif
// transpose columns [j0,j1) to rows [j0,j1) of 'to'
void transposecolumns (ssematrixbase & to, size_t j0, size_t j1) const
@ -1149,7 +1158,7 @@ public:
foreach_coord (i, j, us)
if (std::isnan (us(i,j)))
{
fprintf (stderr, "hasnan: NaN detected at %s (%d,%d)\n", name, i, j);
fprintf (stderr, "hasnan: NaN detected at %s (%d,%d)\n", name, (int)i, (int)j);
return true;
}
#endif
@ -1200,7 +1209,7 @@ class ssematrixfrombuffer : public ssematrixbase
{
void operator= (const ssematrixfrombuffer &); ssematrixfrombuffer (const ssematrixfrombuffer &); // base cannot be assigned except by move
public:
ssematrixfrombuffer() { clear(); }
ssematrixfrombuffer() { this->clear(); }
// instantiate from a float vector --buffer must be SSE-aligned
template<class VECTOR> ssematrixfrombuffer (VECTOR & buffer, size_t n, size_t m) : ssematrixbase (buffer, n, m) {}
@ -1233,10 +1242,10 @@ public:
assert (other.empty() || j0 + m <= other.cols());
if (!other.empty() && j0 + m > other.cols()) // (runtime check to be sure--we use this all the time)
throw std::logic_error ("ssematrixstriperef: stripe outside original matrix' dimension");
p = other.empty() ? NULL : &other(0,j0);
numrows = other.rows();
numcols = m;
colstride = other.getcolstride();
this->p = other.empty() ? NULL : &other(0,j0);
this->numrows = other.rows();
this->numcols = m;
this->colstride = other.getcolstride();
}
// only assignment is by rvalue reference
@ -1255,14 +1264,20 @@ public:
template<class ssematrixbase> class ssematrix : public ssematrixbase
{
// helpers for SSE-compatible memory allocation
static __declspec(noreturn) void failed (size_t nbytes) { static/*not thread-safe--for diagnostics only*/ char buf[80] = { 0 }; sprintf_s (buf, "allocation of SSE vector failed (%d bytes)", nbytes); throw std::bad_exception (buf); }
#if 1 // TODO: move to separate header file numahelpers.h
template<typename T> static T * new_sse (size_t nbytes) { T * pv = (T *) msra::numa::malloc (nbytes * sizeof (T), 16); if (pv) return pv; failed (nbytes * sizeof (T)); }
static void delete_sse (void * p) { if (p) msra::numa::free (p); }
#else
#ifdef _MSC_VER
static __declspec_noreturn void failed(size_t nbytes) { static/*not thread-safe--for diagnostics only*/ char buf[80] = { 0 }; sprintf_s(buf, "allocation of SSE vector failed (%d bytes)", nbytes); throw std::bad_exception(buf); }
#endif
#ifdef __unix__
static void failed (size_t nbytes) { static/*not thread-safe--for diagnostics only*/ char buf[80] = { 0 }; sprintf_s (buf, sizeof(buf), "allocation of SSE vector failed (%d bytes)", (int)nbytes); throw std::bad_exception (); }
#endif
#ifdef _WIN32
template<typename T> static T * new_sse (size_t nbytes) { T * pv = (T *) _aligned_malloc (nbytes * sizeof (T), 16); if (pv) return pv; failed (nbytes * sizeof (T)); }
static void delete_sse (void * p) { if (p) _aligned_free (p); }
#endif
#ifdef __unix__
template<typename T> static T * new_sse (size_t nbytes) { T * pv = (T *) _mm_malloc (nbytes * sizeof (T),16); if (pv) return pv; failed (nbytes * sizeof (T)); }
static void delete_sse (void * p) { if (p) _mm_free (p); }
#endif
// helper to assign a copy from another matrix
void assign (const ssematrixbase & other)
@ -1272,18 +1287,18 @@ template<class ssematrixbase> class ssematrix : public ssematrixbase
};
public:
// construction
ssematrix() { clear(); }
ssematrix (size_t n, size_t m) { clear(); resize (n, m); }
ssematrix (size_t n) { clear(); resize (n, 1); } // vector
ssematrix (const ssematrix & other) { clear(); assign (other); }
ssematrix (const ssematrixbase & other) { clear(); assign (other); }
ssematrix (ssematrix && other) { move (other); }
ssematrix (const std::vector<float> & other) { clear(); resize (other.size(), 1); foreach_index (k, other) (*this)[k] = other[k]; }
ssematrix() { this->clear(); }
ssematrix (size_t n, size_t m) { this->clear(); resize (n, m); }
ssematrix (size_t n) { this->clear(); resize (n, 1); } // vector
ssematrix (const ssematrix & other) { this->clear(); assign (other); }
ssematrix (const ssematrixbase & other) { this->clear(); assign (other); }
ssematrix (ssematrix && other) { this->move (other); }
ssematrix (const std::vector<float> & other) { this->clear(); resize (other.size(), 1); foreach_index (k, other) (*this)[k] = other[k]; }
// construct elementwise with a function f(i,j)
template<typename FUNCTION> ssematrix (size_t n, size_t m, const FUNCTION & f)
{
clear();
this->clear();
resize (n, m);
auto & us = *this;
foreach_coord (i, j, us)
@ -1291,12 +1306,12 @@ public:
}
// destructor
~ssematrix() { delete_sse (p); }
~ssematrix() { delete_sse (this->p); }
// assignment
ssematrix & operator= (const ssematrix & other) { assign (other); return *this; }
ssematrix & operator= (const ssematrixbase & other) { assign (other); return *this; }
ssematrix & operator= (ssematrix && other) { delete_sse(p); move (other); return *this; }
ssematrix & operator= (ssematrix && other) { delete_sse(this->p); move (other); return *this; }
void swap (ssematrix & other) throw() { ssematrixbase::swap (other); }
@ -1304,23 +1319,23 @@ public:
// One or both dimensions can be 0, for special purposes.
void resize (size_t n, size_t m)
{
if (n == numrows && m == numcols)
if (n == this->numrows && m == this->numcols)
return; // no resize needed
const size_t newcolstride = (n + 3) & ~3; // pad to multiples of four floats (required SSE alignment)
const size_t totalelem = newcolstride * m;
//fprintf (stderr, "resize (%d, %d) allocating %d elements\n", n, m, totalelem);
float * pnew = totalelem > 0 ? new_sse<float> (totalelem) : NULL;
::swap (p, pnew);
::swap (this->p, pnew);
delete_sse (pnew); // pnew is now the old p
numrows = n; numcols = m;
colstride = newcolstride;
this->numrows = n; this->numcols = m;
this->colstride = newcolstride;
// touch the memory to ensure the page is created
for (size_t offset = 0; offset < totalelem; offset += 4096 / sizeof (float))
p[offset] = 0.0f; //nan;
this->p[offset] = 0.0f; //nan;
// clear padding elements (numrows <= i < colstride) to 0.0 for SSE optimization
for (size_t j = 0; j < numcols; j++)
for (size_t i = numrows; i < colstride; i++)
p[j * colstride + i] = 0.0f;
for (size_t j = 0; j < this->numcols; j++)
for (size_t i = this->numrows; i < this->colstride; i++)
this->p[j * this->colstride + i] = 0.0f;
#if 1 // for debugging: set all elements to 0
// We keep this code alive because allocations are supposed to be done at the start only.
auto & us = *this;
@ -1335,8 +1350,8 @@ public:
void resizeonce (size_t n, size_t m)
{
#if 1 // BUGBUG: at end of epoch, resizes are OK... so we log but allow them
if (!empty() && (n != numrows || m != numcols))
fprintf (stderr, "resizeonce: undesired resize from %d x %d to %d x %d\n", numrows, numcols, n, m);
if (!this->empty() && (n != this->numrows || m != this->numcols))
fprintf (stderr, "resizeonce: undesired resize from %d x %d to %d x %d\n", this->numrows, this->numcols, n, m);
resize (n, m);
#else
if (empty())
@ -1349,10 +1364,10 @@ public:
// non-destructive resize() to a smaller size
void shrink(size_t newrows, size_t newcols)
{
if (newrows > numrows || newcols > numcols)
if (newrows > this->numrows || newcols > this->numcols)
throw std::logic_error ("shrink: attempted to grow the matrix");
numrows = newrows;
numcols = newcols;
this->numrows = newrows;
this->numcols = newcols;
}
// file I/O
@ -1360,8 +1375,8 @@ public:
{
fputTag (f, "BMAT");
fputstring (f, name);
fputint (f, (int) numrows);
fputint (f, (int) numcols);
fputint (f, (int) this->numrows);
fputint (f, (int) this->numcols);
const auto & us = *this;
foreach_column (j, us)
{
@ -1375,8 +1390,8 @@ public:
{
fputTag(f, "BMAT");
fputstring (f, name);
fputint (f, (int) numrows);
fputint (f, (int) numcols);
fputint (f, (int) this->numrows);
fputint (f, (int) this->numcols);
const auto & us = *this;
foreach_column (j, us)
{
@ -1426,9 +1441,9 @@ public:
}
// paging support (used in feature source)
void topagefile (FILE * f) const { if (!empty()) fwriteOrDie (p, sizeinpagefile(), 1, f); }
void frompagefile (FILE * f) { if (!empty()) freadOrDie (p, sizeinpagefile(), 1, f); }
size_t sizeinpagefile() const { return colstride * numcols * sizeof (*p); }
void topagefile (FILE * f) const { if (!this->empty()) fwriteOrDie (this->p, sizeinpagefile(), 1, f); }
void frompagefile (FILE * f) { if (!this->empty()) freadOrDie (this->p, sizeinpagefile(), 1, f); }
size_t sizeinpagefile() const { return this->colstride * this->numcols * sizeof (*(this->p)); }
// getting a one-column sub-view on this
ssematrixstriperef<ssematrixbase> col (size_t j)
@ -1541,7 +1556,7 @@ template<class M> pair<unsigned int,unsigned int> printmatvaluedistributionf (co
const size_t numparts = 100;
for (size_t i=1; i<=numparts; i++)
{
fprintf (stderr, "%.5f%% absolute values are under %.10f\n", i*100.0/numparts, vals[min(num-1,i*num/numparts)]);
fprintf (stderr, "%.5f%% absolute values are under %.10f\n", i*100.0/numparts, vals[min((size_t)num-1,i*num/numparts)]);
}
fprintf (stderr, "\n%.5f%% values are zero\n\n", 100.0*numzeros/num);
#endif

Просмотреть файл

@ -12,9 +12,9 @@
#include "Platform.h"
#define _CRT_SECURE_NO_WARNINGS // "secure" CRT not available on all platforms
#include "targetver.h"
#ifndef __unix__
#include "targetver.h"
#define WIN32_LEAN_AND_MEAN // Exclude rarely-used stuff from Windows headers
// Windows Header Files:
#include <windows.h>

Просмотреть файл

@ -113,7 +113,7 @@ class minibatchutterancesource : public minibatchsource
if (featdim == 0)
{
reader.getinfo (utteranceset[0].parsedpath, featkind, featdim, sampperiod);
fprintf (stderr, "requiredata: determined feature kind as %d-dimensional '%s' with frame shift %.1f ms\n", featdim, featkind.c_str(), sampperiod / 1e4);
fprintf (stderr, "requiredata: determined feature kind as %d-dimensional '%s' with frame shift %.1f ms\n", (int)featdim, featkind.c_str(), sampperiod / 1e4);
}
// read all utterances; if they are in the same archive, htkfeatreader will be efficient in not closing the file
frames.resize (featdim, totalframes);
@ -130,7 +130,7 @@ class minibatchutterancesource : public minibatchsource
latticesource.getlattices (utteranceset[i].key(), lattices[i], uttframes.cols());
}
//fprintf (stderr, "\n");
fprintf (stderr, "requiredata: %d utterances read\n", utteranceset.size());
fprintf (stderr, "requiredata: %d utterances read\n", (int)utteranceset.size());
}
catch (...)
{
@ -199,17 +199,17 @@ class minibatchutterancesource : public minibatchsource
}
};
std::vector<utteranceref> randomizedutterancerefs; // [pos] randomized utterance ids
std::hash_map<size_t,size_t> randomizedutteranceposmap; // [globalts] -> pos lookup table
std::unordered_map<size_t, size_t> randomizedutteranceposmap; // [globalts] -> pos lookup table
struct positionchunkwindow // chunk window required in memory when at a certain position, for controlling paging
{
std::vector<chunk>::const_iterator definingchunk; // the chunk in randomizedchunks[] that defined the utterance position of this utterance
std::vector<chunk>::iterator definingchunk; // the chunk in randomizedchunks[] that defined the utterance position of this utterance
size_t windowbegin() const { return definingchunk->windowbegin; }
size_t windowend() const { return definingchunk->windowend; }
bool isvalidforthisposition (const utteranceref & utt) const
{
return utt.chunkindex >= windowbegin() && utt.chunkindex < windowend(); // check if 'utt' lives in is in allowed range for this position
}
positionchunkwindow (std::vector<chunk>::const_iterator definingchunk) : definingchunk (definingchunk) {}
positionchunkwindow (std::vector<chunk>::iterator definingchunk) : definingchunk (definingchunk) {}
};
std::vector<positionchunkwindow> positionchunkwindows; // [utterance position] -> [windowbegin, windowend) for controlling paging
@ -297,7 +297,7 @@ public:
throw std::runtime_error ("minibatchutterancesource: utterances < 2 frames not supported");
if (uttframes > frameref::maxframesperutterance)
{
fprintf (stderr, "minibatchutterancesource: skipping %d-th file (%d frames) because it exceeds max. frames (%d) for frameref bit field: %S", i, uttframes, frameref::maxframesperutterance, key.c_str());
fprintf (stderr, "minibatchutterancesource: skipping %d-th file (%d frames) because it exceeds max. frames (%d) for frameref bit field: %S", i, (int)uttframes, (int)frameref::maxframesperutterance, key.c_str());
continue;
}
@ -331,7 +331,7 @@ public:
size_t labframes = labseq.empty() ? 0 : (labseq[labseq.size()-1].firstframe + labseq[labseq.size()-1].numframes);
if (labframes != uttframes)
{
fprintf (stderr, " [duration mismatch (%d in label vs. %d in feat file), skipping %S]", labframes, uttframes, key.c_str());
fprintf (stderr, " [duration mismatch (%d in label vs. %d in feat file), skipping %S]", (int)labframes, (int)uttframes, key.c_str());
nomlf++;
continue; // skip this utterance at all
}
@ -347,7 +347,7 @@ public:
throw std::runtime_error ("CLASSIDTYPE has too few bits");
for (size_t t = e.firstframe; t < e.firstframe + e.numframes; t++)
classids.push_back ((CLASSIDTYPE) e.classid);
numclasses = max (numclasses, 1u + e.classid);
numclasses = max (numclasses, (size_t)(1u + e.classid));
counts.resize (numclasses, 0);
counts[e.classid] += e.numframes;
}
@ -360,7 +360,7 @@ public:
throw std::logic_error (msra::strfun::strprintf ("minibatchutterancesource: label duration inconsistent with feature file in MLF label set: %S", key.c_str()));
assert (labels.empty() || classids.size() == _totalframes + utteranceset.size());
}
fprintf (stderr, " %d frames in %d out of %d utterances; %d classes\n", _totalframes, utteranceset.size(),infiles.size(), numclasses);
fprintf (stderr, " %d frames in %d out of %d utterances; %d classes\n", (int)_totalframes, (int)utteranceset.size(), (int)infiles.size(), (int)numclasses);
if (!labels.empty())
foreach_index (i, utteranceset)
{
@ -369,7 +369,7 @@ public:
}
if (nomlf + nolat > 0)
{
fprintf (stderr, "minibatchutterancesource: out of %d files, %d files not found in label set and %d have no lattice\n", infiles.size(), nomlf, nolat);
fprintf (stderr, "minibatchutterancesource: out of %d files, %d files not found in label set and %d have no lattice\n", (int)infiles.size(), (int)nomlf, (int)nolat);
if (nomlf + nolat > infiles.size() / 2)
throw std::runtime_error ("minibatchutterancesource: too many files not found in label set--assuming broken configuration\n");
}
@ -398,7 +398,7 @@ public:
}
numutterances = utteranceset.size();
fprintf (stderr, "minibatchutterancesource: %d utterances grouped into %d chunks, av. chunk size: %.1f utterances, %.1f frames\n",
numutterances, allchunks.size(), numutterances / (double) allchunks.size(), _totalframes / (double) allchunks.size());
(int)numutterances, (int)allchunks.size(), numutterances / (double) allchunks.size(), _totalframes / (double) allchunks.size());
// Now utterances are stored exclusively in allchunks[]. They are never referred to by a sequential utterance id at this point, only by chunk/within-chunk index.
// preliminary mem allocation for frame references (if in frame mode)
@ -462,7 +462,7 @@ private:
return sweep;
currentsweep = sweep;
fprintf (stderr, "lazyrandomization: re-randomizing for sweep %d in %s mode\n", currentsweep, framemode ? "frame" : "utterance");
fprintf (stderr, "lazyrandomization: re-randomizing for sweep %d in %s mode\n", (int)currentsweep, framemode ? "frame" : "utterance");
const size_t sweepts = sweep * _totalframes; // first global frame index for this sweep
@ -751,7 +751,7 @@ private:
if (verbosity)
fprintf (stderr, "releaserandomizedchunk: paging out randomized chunk %d (frame range [%d..%d]), %d resident in RAM\n",
k, randomizedchunks[k].globalts, randomizedchunks[k].globalte()-1, chunksinram-1);
(int)k, (int)randomizedchunks[k].globalts, (int)(randomizedchunks[k].globalte()-1), (int)(chunksinram-1));
chunkdata.releasedata();
chunksinram--;
}
@ -770,7 +770,7 @@ private:
return false;
if (verbosity)
fprintf (stderr, "requirerandomizedchunk: paging in randomized chunk %d (frame range [%d..%d]), %d resident in RAM\n", chunkindex, chunk.globalts, chunk.globalte()-1, chunksinram+1);
fprintf (stderr, "requirerandomizedchunk: paging in randomized chunk %d (frame range [%d..%d]), %d resident in RAM\n", (int)chunkindex, (int)chunk.globalts, (int)(chunk.globalte()-1), (int)(chunksinram+1));
msra::util::attempt (5, [&]() // (reading from network)
{
chunkdata.requiredata (featkind, featdim, sampperiod, this->lattices);
@ -861,7 +861,7 @@ public:
// return these utterances
if (verbosity > 0)
fprintf (stderr, "getbatch: getting utterances %d..%d (%d frames out of %d requested) in sweep %d\n", spos, epos -1, mbframes, framesrequested, sweep);
fprintf (stderr, "getbatch: getting utterances %d..%d (%d frames out of %d requested) in sweep %d\n", (int)spos, (int)(epos -1), (int)mbframes, (int)framesrequested, (int)sweep);
size_t tspos = 0; // relative start of utterance 'pos' within the returned minibatch
for (size_t pos = spos; pos < epos; pos++)
{
@ -927,7 +927,7 @@ public:
const size_t windowend = randomizedchunks[lastchunk].windowend;
if (verbosity > 0)
fprintf (stderr, "getbatch: getting randomized frames [%d..%d] (%d frames out of %d requested) in sweep %d; chunks [%d..%d] -> chunk window [%d..%d)\n",
globalts, globalte, mbframes, framesrequested, sweep, firstchunk, lastchunk, windowbegin, windowend);
(int)globalts, (int)globalte, (int)mbframes, (int)framesrequested, (int)sweep, (int)firstchunk, (int)lastchunk, (int)windowbegin, (int)windowend);
// release all data outside, and page in all data inside
for (size_t k = 0; k < windowbegin; k++)
releaserandomizedchunk (k);

Просмотреть файл

@ -54,9 +54,14 @@ class minibatchutterancesourcemulti : public minibatchsource
size_t numframes() const { return parsedpath.numframes(); }
const wstring key() const // key used for looking up lattice (not stored to save space)
{
#ifdef _WIN32
static const wstring emptywstring;
static const wregex deleteextensionre (L"\\.[^\\.\\\\/:]*$");
return regex_replace (logicalpath(), deleteextensionre, emptywstring); // delete extension (or not if none)
#endif
#ifdef __unix__
return removeExtension(basename(logicalpath()));
#endif
}
};
struct utterancechunkdata // data for a chunk of utterances
@ -116,7 +121,7 @@ class minibatchutterancesourcemulti : public minibatchsource
if (featdim == 0)
{
reader.getinfo (utteranceset[0].parsedpath, featkind, featdim, sampperiod);
fprintf (stderr, "requiredata: determined feature kind as %d-dimensional '%s' with frame shift %.1f ms\n", featdim, featkind.c_str(), sampperiod / 1e4);
fprintf (stderr, "requiredata: determined feature kind as %d-dimensional '%s' with frame shift %.1f ms\n", (int)featdim, featkind.c_str(), sampperiod / 1e4);
}
// read all utterances; if they are in the same archive, htkfeatreader will be efficient in not closing the file
frames.resize (featdim, totalframes);
@ -134,7 +139,7 @@ class minibatchutterancesourcemulti : public minibatchsource
}
//fprintf (stderr, "\n");
if (verbosity)
fprintf (stderr, "requiredata: %d utterances read\n", utteranceset.size());
fprintf (stderr, "requiredata: %d utterances read\n", (int)utteranceset.size());
}
catch (...)
{
@ -203,41 +208,30 @@ class minibatchutterancesourcemulti : public minibatchsource
}
};
std::vector<utteranceref> randomizedutterancerefs; // [pos] randomized utterance ids
std::hash_map<size_t,size_t> randomizedutteranceposmap; // [globalts] -> pos lookup table
std::unordered_map<size_t, size_t> randomizedutteranceposmap; // [globalts] -> pos lookup table
struct positionchunkwindow // chunk window required in memory when at a certain position, for controlling paging
{
std::vector<chunk>::const_iterator definingchunk; // the chunk in randomizedchunks[] that defined the utterance position of this utterance
std::vector<chunk>::iterator definingchunk; // the chunk in randomizedchunks[] that defined the utterance position of this utterance
size_t windowbegin() const { return definingchunk->windowbegin; }
size_t windowend() const { return definingchunk->windowend; }
bool isvalidforthisposition (const utteranceref & utt) const
{
return utt.chunkindex >= windowbegin() && utt.chunkindex < windowend(); // check if 'utt' lives in is in allowed range for this position
}
positionchunkwindow (std::vector<chunk>::const_iterator definingchunk) : definingchunk (definingchunk) {}
positionchunkwindow (std::vector<chunk>::iterator definingchunk) : definingchunk (definingchunk) {}
};
std::vector<positionchunkwindow> positionchunkwindows; // [utterance position] -> [windowbegin, windowend) for controlling paging
// frame-level randomization layered on top of utterance chunking (randomized, where randomization is cached)
struct frameref
{
#ifdef _WIN64 // (sadly, the compiler makes this 8 bytes, not 6)
unsigned short chunkindex; // lives in this chunk (index into randomizedchunks[])
unsigned short utteranceindex; // utterance index in that chunk
static const size_t maxutterancesperchunk = 65535;
unsigned short frameindex; // frame index within the utterance
static const size_t maxframesperutterance = 65535;
#else // For Win32, we care to keep it inside 32 bits. We have already encountered setups where that's not enough.
unsigned int chunkindex : 13; // lives in this chunk (index into randomizedchunks[])
unsigned int utteranceindex : 8; // utterance index in that chunk
static const size_t maxutterancesperchunk = 255;
unsigned int frameindex : 11; // frame index within the utterance
static const size_t maxframesperutterance = 2047;
#endif
frameref (size_t ci, size_t ui, size_t fi) : chunkindex ((unsigned short) ci), utteranceindex ((unsigned short) ui), frameindex ((unsigned short) fi)
{
#ifndef _WIN64
static_assert (sizeof (frameref) == 4, "frameref: bit fields too large to fit into 32-bit integer");
#endif
if (ci == chunkindex && ui == utteranceindex && fi == frameindex)
return;
throw std::logic_error ("frameref: bit fields too small");
@ -334,8 +328,8 @@ public:
// first check consistency across feature streams
// We'll go through the SCP files for each stream to make sure the duration is consistent
// If not, we'll plan to ignore the utterance, and inform the user
// m indexes the feature stream
// i indexes the files within a stream, i.e. in the SCP file)
// m indexes the feature stream
// i indexes the files within a stream, i.e. in the SCP file)
foreach_index(m, infiles){
if (m == 0){
numutts = infiles[m].size();
@ -353,7 +347,7 @@ public:
throw std::runtime_error("minibatchutterancesource: utterances < 2 frames not supported");
if (uttframes > frameref::maxframesperutterance)
{
fprintf(stderr, "minibatchutterancesource: skipping %d-th file (%d frames) because it exceeds max. frames (%d) for frameref bit field: %S\n", i, uttframes, frameref::maxframesperutterance, key.c_str());
fprintf(stderr, "minibatchutterancesource: skipping %d-th file (%d frames) because it exceeds max. frames (%d) for frameref bit field: %S\n", i, (int)uttframes, (int)frameref::maxframesperutterance, key.c_str());
uttduration[i] = 0;
uttisvalid[i] = false;
}
@ -363,7 +357,7 @@ public:
uttisvalid[i] = true;
}
else if (uttduration[i] != uttframes){
fprintf(stderr, "minibatchutterancesource: skipping %d-th file due to inconsistency in duration in different feature streams (%d vs %d frames)\n", i, uttduration[i], uttframes);
fprintf(stderr, "minibatchutterancesource: skipping %d-th file due to inconsistency in duration in different feature streams (%d vs %d frames)\n", i, (int)uttduration[i], (int)uttframes);
uttduration[i] = 0;
uttisvalid[i] = false;
}
@ -378,7 +372,7 @@ public:
if (invalidutts > uttisvalid.size() / 2)
throw std::runtime_error("minibatchutterancesource: too many files with inconsistent durations, assuming broken configuration\n");
else if (invalidutts>0)
fprintf(stderr, "Found inconsistent durations across feature streams in %d out of %d files\n", invalidutts, uttisvalid.size());
fprintf(stderr, "Found inconsistent durations across feature streams in %d out of %d files\n", (int)invalidutts, (int)uttisvalid.size());
// now process the features and labels
@ -459,7 +453,7 @@ public:
size_t labframes = labseq.empty() ? 0 : (labseq[labseq.size()-1].firstframe + labseq[labseq.size()-1].numframes);
if (labframes != uttframes)
{
fprintf (stderr, " [duration mismatch (%d in label vs. %d in feat file), skipping %S]", labframes, uttframes, key.c_str());
fprintf (stderr, " [duration mismatch (%d in label vs. %d in feat file), skipping %S]", (int)labframes, (int)uttframes, key.c_str());
nomlf++;
uttisvalid[i] = false;
//continue; // skip this utterance at all
@ -484,13 +478,13 @@ public:
}
if (e.classid >= udim[j])
{
throw std::runtime_error(msra::strfun::strprintf("minibatchutterancesource: class id %d exceeds model output dimension %d in file %S", e.classid, udim, key.c_str()));
throw std::runtime_error(msra::strfun::strprintf("minibatchutterancesource: class id %d exceeds model output dimension %d in file %S", e.classid, udim[j], key.c_str()));
}
if (e.classid != (CLASSIDTYPE) e.classid)
throw std::runtime_error ("CLASSIDTYPE has too few bits");
for (size_t t = e.firstframe; t < e.firstframe + e.numframes; t++)
classids[j]->push_back ((CLASSIDTYPE) e.classid);
numclasses[j] = max (numclasses[j], 1u + e.classid);
numclasses[j] = max (numclasses[j], (size_t)(1u + e.classid));
counts[j].resize (numclasses[j], 0);
counts[j][e.classid] += e.numframes;
}
@ -521,7 +515,7 @@ public:
else
assert(utteranceset.size() == utterancesetsize);
fprintf (stderr, "feature set %d: %d frames in %d out of %d utterances\n", m, _totalframes, utteranceset.size(),infiles[m].size());
fprintf (stderr, "feature set %d: %d frames in %d out of %d utterances\n", m, (int)_totalframes, (int)utteranceset.size(), (int)infiles[m].size());
if (!labels.empty()){
foreach_index (j, labels){
@ -538,11 +532,11 @@ public:
}
if (nomlf + nolat > 0)
{
fprintf (stderr, "minibatchutterancesource: out of %d files, %d files not found in label set and %d have no lattice\n", infiles[0].size(), nomlf, nolat);
fprintf (stderr, "minibatchutterancesource: out of %d files, %d files not found in label set and %d have no lattice\n", (int)infiles[0].size(), (int)nomlf, (int)nolat);
if (nomlf + nolat > infiles[m].size() / 2)
throw std::runtime_error ("minibatchutterancesource: too many files not found in label set--assuming broken configuration\n");
}
if (m==0) {foreach_index(j, numclasses) { fprintf(stderr,"label set %d: %d classes\n",j, numclasses[j]); } }
if (m==0) {foreach_index(j, numclasses) { fprintf(stderr,"label set %d: %d classes\n", j, (int)numclasses[j]); } }
// distribute them over chunks
// We simply count off frames until we reach the chunk size.
// Note that we first randomize the chunks, i.e. when used, chunks are non-consecutive and thus cause the disk head to seek for each chunk.
@ -568,7 +562,7 @@ public:
}
numutterances = utteranceset.size();
fprintf (stderr, "minibatchutterancesource: %d utterances grouped into %d chunks, av. chunk size: %.1f utterances, %.1f frames\n",
numutterances, thisallchunks.size(), numutterances / (double) thisallchunks.size(), _totalframes / (double) thisallchunks.size());
(int)numutterances, (int)thisallchunks.size(), numutterances / (double) thisallchunks.size(), _totalframes / (double) thisallchunks.size());
// Now utterances are stored exclusively in allchunks[]. They are never referred to by a sequential utterance id at this point, only by chunk/within-chunk index.
}
// preliminary mem allocation for frame references (if in frame mode)
@ -657,7 +651,7 @@ private:
currentsweep = sweep;
if (verbosity>0)
fprintf (stderr, "lazyrandomization: re-randomizing for sweep %d in %s mode\n", currentsweep, framemode ? "frame" : "utterance");
fprintf (stderr, "lazyrandomization: re-randomizing for sweep %d in %s mode\n", (int)currentsweep, framemode ? "frame" : "utterance");
const size_t sweepts = sweep * _totalframes; // first global frame index for this sweep
@ -968,7 +962,7 @@ private:
{
if (verbosity)
fprintf (stderr, "releaserandomizedchunk: paging out randomized chunk %d (frame range [%d..%d]), %d resident in RAM\n",
k, randomizedchunks[m][k].globalts, randomizedchunks[m][k].globalte()-1, chunksinram-1);
(int)k, (int)randomizedchunks[m][k].globalts, (int)(randomizedchunks[m][k].globalte()-1), (int)(chunksinram-1));
chunkdata.releasedata();
numreleased++;
}
@ -1010,7 +1004,7 @@ private:
auto & chunk = randomizedchunks[m][chunkindex];
auto & chunkdata = chunk.getchunkdata();
if (verbosity)
fprintf (stderr, "feature set %d: requirerandomizedchunk: paging in randomized chunk %d (frame range [%d..%d]), %d resident in RAM\n", m, chunkindex, chunk.globalts, chunk.globalte()-1, chunksinram+1);
fprintf (stderr, "feature set %d: requirerandomizedchunk: paging in randomized chunk %d (frame range [%d..%d]), %d resident in RAM\n", m, (int)chunkindex, (int)chunk.globalts, (int)(chunk.globalte()-1), (int)(chunksinram+1));
msra::util::attempt (5, [&]() // (reading from network)
{
chunkdata.requiredata (featkind[m], featdim[m], sampperiod[m], this->lattices, verbosity);
@ -1154,7 +1148,7 @@ public:
}
// return these utterances
if (verbosity > 0)
fprintf(stderr, "getbatch: getting utterances %d..%d (%d subset of %d frames out of %d requested) in sweep %d\n", spos, epos - 1, tspos, mbframes, framesrequested, sweep);
fprintf(stderr, "getbatch: getting utterances %d..%d (%d subset of %d frames out of %d requested) in sweep %d\n", (int)spos, (int)(epos - 1), (int)tspos, (int)mbframes, (int)framesrequested, (int)sweep);
tspos = 0; // relative start of utterance 'pos' within the returned minibatch
for (size_t pos = spos; pos < epos; pos++)
{
@ -1239,9 +1233,9 @@ public:
const size_t lastchunk = chunkforframepos (globalte-1);
const size_t windowbegin = randomizedchunks[0][firstchunk].windowbegin;
const size_t windowend = randomizedchunks[0][lastchunk].windowend;
if (verbosity)
if (verbosity > 0)
fprintf (stderr, "getbatch: getting randomized frames [%d..%d] (%d frames out of %d requested) in sweep %d; chunks [%d..%d] -> chunk window [%d..%d)\n",
globalts, globalte, mbframes, framesrequested, sweep, firstchunk, lastchunk, windowbegin, windowend);
(int)globalts, (int)globalte, (int)mbframes, (int)framesrequested, (int)sweep, (int)firstchunk, (int)lastchunk, (int)windowbegin, (int)windowend);
// release all data outside, and page in all data inside
for (size_t k = 0; k < windowbegin; k++)
releaserandomizedchunk (k);

Просмотреть файл

@ -1,63 +0,0 @@
//
// <copyright file="DataReader.cpp" company="Microsoft">
// Copyright (c) Microsoft Corporation. All rights reserved.
// </copyright>
//
// DataReader.cpp : Defines the exported functions for the DLL application.
//
#include "stdafx.h"
#include "basetypes.h"
#include "htkfeatio.h" // for reading HTK features
//#include "latticearchive.h" // for reading HTK phoneme lattices (MMI training)
#include "simplesenonehmm.h" // for MMI scoring
//#include "msra_mgram.h" // for unigram scores of ground-truth path in sequence training
#include "rollingwindowsource.h" // minibatch sources
#include "utterancesource.h"
//#include "readaheadsource.h"
#include "chunkevalsource.h"
#define DATAREADER_EXPORTS
#include "DataReader.h"
#include "HTKMLFReader.h"
#include "commandArgUtil.h"
namespace Microsoft { namespace MSR { namespace CNTK {
template<class ElemType>
void DATAREADER_API GetReader(IDataReader<ElemType>** preader)
{
*preader = new HTKMLFReader<ElemType>();
}
extern "C" DATAREADER_API void GetReaderF(IDataReader<float>** preader)
{
GetReader(preader);
}
extern "C" DATAREADER_API void GetReaderD(IDataReader<double>** preader)
{
GetReader(preader);
}
// Utility function, in ConfigFile.cpp, but HTKMLFReader doesn't need that code...
// Trim - trim white space off the start and end of the string
// str - string to trim
// NOTE: if the entire string is empty, then the string will be set to an empty string
/* void Trim(std::string& str)
{
auto found = str.find_first_not_of(" \t");
if (found == npos)
{
str.erase(0);
return;
}
str.erase(0, found);
found = str.find_last_not_of(" \t");
if (found != npos)
str.erase(found+1);
}*/
}}}

Просмотреть файл

@ -1,111 +0,0 @@
//
// <copyright file="DataWriter.cpp" company="Microsoft">
// Copyright (c) Microsoft Corporation. All rights reserved.
// </copyright>
//
// DataWriter.cpp : Defines the exported functions for the DLL application.
//
#include "stdafx.h"
#include "basetypes.h"
#include "htkfeatio.h" // for reading HTK features
#define DATAWRITER_EXPORTS
#include "DataWriter.h"
#include "HTKMLFWriter.h"
namespace Microsoft { namespace MSR { namespace CNTK {
template<class ElemType>
void DATAWRITER_API GetWriter(IDataWriter<ElemType>** pwriter)
{
*pwriter = new HTKMLFWriter<ElemType>();
}
extern "C" DATAWRITER_API void GetWriterF(IDataWriter<float>** pwriter)
{
GetWriter(pwriter);
}
extern "C" DATAWRITER_API void GetWriterD(IDataWriter<double>** pwriter)
{
GetWriter(pwriter);
}
template<class ElemType>
void DataWriter<ElemType>::Init(const ConfigParameters& writerConfig)
{
m_dataWriter = new HTKMLFWriter<ElemType>();
m_dataWriter->Init(writerConfig);
}
template<class ElemType>
void DataWriter<ElemType>::GetDataWriter(const ConfigParameters& /*config*/)
{
NOT_IMPLEMENTED;
}
// Destroy - cleanup and remove this class
// NOTE: this destroys the object, and it can't be used past this point
template<class ElemType>
void DataWriter<ElemType>::Destroy()
{
delete m_dataWriter;
m_dataWriter = NULL;
}
// DataWriter Constructor
// config - [in] configuration data for the data writer
template<class ElemType>
DataWriter<ElemType>::DataWriter(const ConfigParameters& config)
{
Init(config);
}
// destructor - cleanup temp files, etc.
template<class ElemType>
DataWriter<ElemType>::~DataWriter()
{
delete m_dataWriter;
m_dataWriter = NULL;
}
// GetSections - Get the sections of the file
// sections - a map of section name to section. Data sepcifications from config file will be used to determine where and how to save data
template<class ElemType>
void DataWriter<ElemType>::GetSections(std::map<std::wstring, SectionType, nocase_compare>& sections)
{
m_dataWriter->GetSections(sections);
}
// SaveData - save data in the file/files
// recordStart - Starting record number
// matricies - a map of section name (section:subsection) to data pointer. Data sepcifications from config file will be used to determine where and how to save data
// numRecords - number of records we are saving, can be zero if not applicable
// datasetSize - Size of the dataset
// byteVariableSized - for variable sized data, size of current block to be written, zero when not used, or ignored if not variable sized data
template<class ElemType>
bool DataWriter<ElemType>::SaveData(size_t recordStart, const std::map<std::wstring, void*, nocase_compare>& matrices, size_t numRecords, size_t datasetSize, size_t byteVariableSized)
{
return m_dataWriter->SaveData(recordStart, matrices, numRecords, datasetSize, byteVariableSized);
}
// SaveMapping - save a map into the file
// saveId - name of the section to save into (section:subsection format)
// labelMapping - map we are saving to the file
template<class ElemType>
void DataWriter<ElemType>::SaveMapping(std::wstring saveId, const std::map<LabelIdType, LabelType>& labelMapping)
{
m_dataWriter->SaveMapping(saveId, labelMapping);
}
//The explicit instantiation
template class DataWriter<double>;
template class DataWriter<float>;
}}}

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -1,202 +0,0 @@
//
// <copyright file="HTKMLFReader.h" company="Microsoft">
// Copyright (c) Microsoft Corporation. All rights reserved.
// </copyright>
//
// HTKMLFReader.h - Include file for the MTK and MLF format of features and samples
#pragma once
#include "DataReader.h"
#include "commandArgUtil.h" // for intargvector
#include "CUDAPageLockedMemAllocator.h"
namespace Microsoft { namespace MSR { namespace CNTK {
template<class ElemType>
class HTKMLFReader : public IDataReader<ElemType>
{
private:
const static size_t m_htkRandomizeAuto = 0;
const static size_t m_htkRandomizeDisable = (size_t)-1;
msra::dbn::minibatchiterator* m_mbiter;
msra::dbn::minibatchsource* m_frameSource;
//msra::dbn::minibatchreadaheadsource* m_readAheadSource;
msra::dbn::FileEvalSource* m_fileEvalSource;
msra::dbn::latticesource* m_lattices;
map<wstring,msra::lattices::lattice::htkmlfwordsequence> m_latticeMap;
vector<bool> m_sentenceEnd;
bool m_readAhead;
bool m_truncated;
bool m_framemode;
vector<size_t> m_processedFrame;
intargvector m_numberOfuttsPerMinibatchForAllEpochs;
size_t m_numberOfuttsPerMinibatch;
size_t m_actualnumberOfuttsPerMinibatch;
size_t m_mbSize;
vector<size_t> m_toProcess;
vector<size_t> m_switchFrame;
bool m_noData;
bool m_trainOrTest; // if false, in file writing mode
using LabelType = typename IDataReader<ElemType>::LabelType;
using LabelIdType = typename IDataReader<ElemType>::LabelIdType;
std::map<LabelIdType, LabelType> m_idToLabelMap;
bool m_partialMinibatch; // allow partial minibatches?
std::vector<ElemType*> m_featuresBufferMultiUtt;
std::vector<size_t> m_featuresBufferAllocatedMultiUtt;
std::vector<ElemType*> m_labelsBufferMultiUtt;
std::vector<size_t> m_labelsBufferAllocatedMultiUtt;
std::vector<size_t> m_featuresStartIndexMultiUtt;
std::vector<size_t> m_labelsStartIndexMultiUtt;
CUDAPageLockedMemAllocator* m_cudaAllocator;
std::vector<std::shared_ptr<ElemType>> m_featuresBufferMultiIO;
std::vector<size_t> m_featuresBufferAllocatedMultiIO;
std::vector<std::shared_ptr<ElemType>> m_labelsBufferMultiIO;
std::vector<size_t> m_labelsBufferAllocatedMultiIO;
std::map<std::wstring,size_t> m_featureNameToIdMap;
std::map<std::wstring,size_t> m_labelNameToIdMap;
std::map<std::wstring,size_t> m_nameToTypeMap;
std::map<std::wstring,size_t> m_featureNameToDimMap;
std::map<std::wstring,size_t> m_labelNameToDimMap;
// for writing outputs to files (standard single input/output network) - deprecate eventually
bool m_checkDictionaryKeys;
bool m_convertLabelsToTargets;
std::vector <bool> m_convertLabelsToTargetsMultiIO;
std::vector<std::vector<std::wstring>> m_inputFilesMultiIO;
size_t m_inputFileIndex;
std::vector<size_t> m_featDims;
std::vector<size_t> m_labelDims;
std::vector<std::vector<std::vector<ElemType>>>m_labelToTargetMapMultiIO;
void PrepareForTrainingOrTesting(const ConfigParameters& config);
void PrepareForWriting(const ConfigParameters& config);
bool GetMinibatchToTrainOrTest(std::map<std::wstring, Matrix<ElemType>*>&matrices);
bool GetMinibatchToWrite(std::map<std::wstring, Matrix<ElemType>*>&matrices);
void StartMinibatchLoopToTrainOrTest(size_t mbSize, size_t epoch, size_t subsetNum, size_t numSubsets, size_t requestedEpochSamples = requestDataSize);
void StartMinibatchLoopToWrite(size_t mbSize, size_t epoch, size_t requestedEpochSamples=requestDataSize);
bool ReNewBufferForMultiIO(size_t i);
size_t NumberSlicesInEachRecurrentIter() { return m_numberOfuttsPerMinibatch ;}
void SetNbrSlicesEachRecurrentIter(const size_t) { };
void GetDataNamesFromConfig(const ConfigParameters& readerConfig, std::vector<std::wstring>& features, std::vector<std::wstring>& labels);
size_t ReadLabelToTargetMappingFile (const std::wstring& labelToTargetMappingFile, const std::wstring& labelListFile, std::vector<std::vector<ElemType>>& labelToTargetMap);
void ExpandDotDotDot(wstring & featPath, const wstring & scpPath, wstring & scpDirCached);
enum InputOutputTypes
{
real,
category,
};
private:
CUDAPageLockedMemAllocator* GetCUDAAllocator(int deviceID)
{
if (m_cudaAllocator != nullptr)
{
if (m_cudaAllocator->GetDeviceID() != deviceID)
{
delete m_cudaAllocator;
m_cudaAllocator = nullptr;
}
}
if (m_cudaAllocator == nullptr)
{
m_cudaAllocator = new CUDAPageLockedMemAllocator(deviceID);
}
return m_cudaAllocator;
}
std::shared_ptr<ElemType> AllocateIntermediateBuffer(int deviceID, size_t numElements)
{
if (deviceID >= 0)
{
// Use pinned memory for GPU devices for better copy performance
size_t totalSize = sizeof(ElemType) * numElements;
return std::shared_ptr<ElemType>((ElemType*)GetCUDAAllocator(deviceID)->Malloc(totalSize), [this, deviceID](ElemType* p) {
this->GetCUDAAllocator(deviceID)->Free((char*)p);
});
}
else
{
return std::shared_ptr<ElemType>(new ElemType[numElements], [](ElemType* p) {
delete[] p;
});
}
}
public:
/// a matrix of n_stream x n_length
/// n_stream is the number of streams
/// n_length is the maximum lenght of each stream
/// for example, two sentences used in parallel in one minibatch would be
/// [2 x 5] if the max length of one of the sentences is 5
/// the elements of the matrix is 0, 1, or -1, defined as SEQUENCE_START, SEQUENCE_MIDDLE, NO_INPUT in cbasetype.h
/// 0 1 1 0 1
/// 1 0 1 0 0
/// for two parallel data streams. The first has two sentences, with 0 indicating begining of a sentence
/// the second data stream has two sentences, with 0 indicating begining of sentences
/// you may use 1 even if a sentence begins at that position, in this case, the trainer will carry over hidden states to the following
/// frame.
Matrix<ElemType> m_sentenceBegin;
/// a matrix of 1 x n_length
/// 1 denotes the case that there exists sentnece begin or no_labels case in this frame
/// 0 denotes such case is not in this frame
vector<MinibatchPackingFlag> m_minibatchPackingFlag;
/// by default it is false
/// if true, reader will set to SEQUENCE_MIDDLE for time positions that are orignally correspond to SEQUENCE_START
/// set to true so that a current minibatch can uses state activities from the previous minibatch.
/// default will have truncated BPTT, which only does BPTT inside a minibatch
bool mIgnoreSentenceBeginTag;
HTKMLFReader() : m_sentenceBegin(CPUDEVICE) {
}
virtual void Init(const ConfigParameters& config);
virtual void Destroy() {delete this;}
virtual ~HTKMLFReader();
virtual void StartMinibatchLoop(size_t mbSize, size_t epoch, size_t requestedEpochSamples = requestDataSize)
{
return StartDistributedMinibatchLoop(mbSize, epoch, 0, 1, requestedEpochSamples);
}
virtual bool SupportsDistributedMBRead() const override
{
return m_frameSource->supportsbatchsubsetting();
}
virtual void StartDistributedMinibatchLoop(size_t mbSize, size_t epoch, size_t subsetNum, size_t numSubsets, size_t requestedEpochSamples = requestDataSize) override;
virtual bool GetMinibatch(std::map<std::wstring, Matrix<ElemType>*>& matrices);
virtual const std::map<LabelIdType, LabelType>& GetLabelMapping(const std::wstring& sectionName);
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, LabelType>& labelMapping);
virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart=0);
virtual bool DataEnd(EndDataType endDataType);
void SetSentenceEndInBatch(vector<size_t> &/*sentenceEnd*/);
void SetSentenceEnd(int /*actualMbSize*/){};
void SetSentenceSegBatch(Matrix<ElemType> &sentenceBegin, vector<MinibatchPackingFlag>& sentenceExistsBeginOrNoLabels);
bool RequireSentenceSeg() { return !m_framemode; };
};
}}}

Просмотреть файл

@ -1,184 +0,0 @@
//
// <copyright file="HTKMLFReader.cpp" company="Microsoft">
// Copyright (c) Microsoft Corporation. All rights reserved.
// </copyright>
//
// HTKMLFReader.cpp : Defines the exported functions for the DLL application.
//
#include "stdafx.h"
#include "basetypes.h"
#include "htkfeatio.h" // for reading HTK features
//#ifndef __unix__
#include "ssematrix.h"
//#endif
#define DATAWRITER_EXPORTS // creating the exports here
#include "DataWriter.h"
#include "commandArgUtil.h"
#include "HTKMLFWriter.h"
#ifdef LEAKDETECT
#include <vld.h> // for memory leak detection
#endif
namespace Microsoft { namespace MSR { namespace CNTK {
// Create a Data Writer
//DATAWRITER_API IDataWriter* DataWriterFactory(void)
template<class ElemType>
void HTKMLFWriter<ElemType>::Init(const ConfigParameters& writerConfig)
{
m_tempArray = nullptr;
m_tempArraySize = 0;
vector<wstring> scriptpaths;
vector<wstring> filelist;
size_t numFiles;
size_t firstfilesonly = SIZE_MAX; // set to a lower value for testing
ConfigArray outputNames = writerConfig("outputNodeNames","");
if (outputNames.size()<1)
RuntimeError("writer needs at least one outputNodeName specified in config");
foreach_index(i, outputNames) // inputNames should map to node names
{
ConfigParameters thisOutput = writerConfig(outputNames[i]);
if (thisOutput.Exists("dim"))
udims.push_back(thisOutput("dim"));
else
RuntimeError("HTKMLFWriter::Init: writer need to specify dim of output");
if (thisOutput.Exists("file"))
scriptpaths.push_back(thisOutput("file"));
else if (thisOutput.Exists("scpFile"))
scriptpaths.push_back(thisOutput("scpFile"));
else
RuntimeError("HTKMLFWriter::Init: writer needs to specify scpFile for output");
outputNameToIdMap[outputNames[i]]= i;
outputNameToDimMap[outputNames[i]]=udims[i];
wstring type = thisOutput("type","Real");
if (type == L"Real")
{
outputNameToTypeMap[outputNames[i]] = OutputTypes::outputReal;
}
else
{
throw std::runtime_error ("HTKMLFWriter::Init: output type for writer output expected to be Real");
}
}
numFiles=0;
foreach_index(i,scriptpaths)
{
filelist.clear();
std::wstring scriptPath = scriptpaths[i];
fprintf(stderr, "HTKMLFWriter::Init: reading output script file %S ...", scriptPath.c_str());
size_t n = 0;
for (msra::files::textreader reader(scriptPath); reader && filelist.size() <= firstfilesonly/*optimization*/; )
{
filelist.push_back (reader.wgetline());
n++;
}
fprintf (stderr, " %zu entries\n", n);
if (i==0)
numFiles=n;
else
if (n!=numFiles)
throw std::runtime_error (msra::strfun::strprintf ("HTKMLFWriter:Init: number of files in each scriptfile inconsistent (%d vs. %d)", numFiles,n));
outputFiles.push_back(filelist);
}
outputFileIndex=0;
sampPeriod=100000;
}
template<class ElemType>
void HTKMLFWriter<ElemType>::Destroy()
{
delete [] m_tempArray;
m_tempArray = nullptr;
m_tempArraySize = 0;
}
template<class ElemType>
void HTKMLFWriter<ElemType>::GetSections(std::map<std::wstring, SectionType, nocase_compare>& /*sections*/)
{
}
template<class ElemType>
bool HTKMLFWriter<ElemType>::SaveData(size_t /*recordStart*/, const std::map<std::wstring, void*, nocase_compare>& matrices, size_t /*numRecords*/, size_t /*datasetSize*/, size_t /*byteVariableSized*/)
{
//std::map<std::wstring, void*, nocase_compare>::iterator iter;
if (outputFileIndex>=outputFiles[0].size())
RuntimeError("index for output scp file out of range...");
for (auto iter = matrices.begin();iter!=matrices.end(); iter++)
{
wstring outputName = iter->first;
Matrix<ElemType>& outputData = *(static_cast<Matrix<ElemType>*>(iter->second));
size_t id = outputNameToIdMap[outputName];
size_t dim = outputNameToDimMap[outputName];
wstring outFile = outputFiles[id][outputFileIndex];
assert(outputData.GetNumRows()==dim); dim;
SaveToFile(outFile,outputData);
}
outputFileIndex++;
return true;
}
template<class ElemType>
void HTKMLFWriter<ElemType>::SaveToFile(std::wstring& outputFile, Matrix<ElemType>& outputData)
{
msra::dbn::matrix output;
output.resize(outputData.GetNumRows(),outputData.GetNumCols());
outputData.CopyToArray(m_tempArray, m_tempArraySize);
ElemType * pValue = m_tempArray;
for (int j=0; j< outputData.GetNumCols(); j++)
{
for (int i=0; i<outputData.GetNumRows(); i++)
{
output(i,j) = (float)*pValue++;
}
}
const size_t nansinf = output.countnaninf();
if (nansinf > 0)
fprintf (stderr, "chunkeval: %d NaNs or INF detected in '%S' (%d frames)\n", (int) nansinf, outputFile.c_str(), (int) output.cols());
// save it
msra::files::make_intermediate_dirs (outputFile);
msra::util::attempt (5, [&]()
{
msra::asr::htkfeatwriter::write (outputFile, "USER", this->sampPeriod, output);
});
fprintf (stderr, "evaluate: writing %zu frames of %S\n", output.cols(), outputFile.c_str());
}
template<class ElemType>
void HTKMLFWriter<ElemType>::SaveMapping(std::wstring saveId, const std::map<LabelIdType, LabelType>& /*labelMapping*/)
{
}
template class HTKMLFWriter<float>;
template class HTKMLFWriter<double>;
}}}

Просмотреть файл

@ -1,47 +0,0 @@
//
// <copyright file="HTKMLFReader.h" company="Microsoft">
// Copyright (c) Microsoft Corporation. All rights reserved.
// </copyright>
//
// HTKMLFReader.h - Include file for the MTK and MLF format of features and samples
#pragma once
#include "DataWriter.h"
#include <map>
#include <vector>
namespace Microsoft { namespace MSR { namespace CNTK {
template<class ElemType>
class HTKMLFWriter : public IDataWriter<ElemType>
{
private:
std::vector<size_t> outputDims;
std::vector<std::vector<std::wstring>> outputFiles;
std::vector<size_t> udims;
std::map<std::wstring,size_t> outputNameToIdMap;
std::map<std::wstring,size_t> outputNameToDimMap;
std::map<std::wstring,size_t> outputNameToTypeMap;
unsigned int sampPeriod;
size_t outputFileIndex;
void SaveToFile(std::wstring& outputFile, Matrix<ElemType>& outputData);
ElemType * m_tempArray;
size_t m_tempArraySize;
enum OutputTypes
{
outputReal,
outputCategory,
};
public:
using LabelType = typename IDataWriter<ElemType>::LabelType;
using LabelIdType = typename IDataWriter<ElemType>::LabelIdType;
virtual void Init(const ConfigParameters& writerConfig);
virtual void Destroy();
virtual void GetSections(std::map<std::wstring, SectionType, nocase_compare>& sections);
virtual bool SaveData(size_t recordStart, const std::map<std::wstring, void*, nocase_compare>& matrices, size_t numRecords, size_t datasetSize, size_t byteVariableSized);
virtual void SaveMapping(std::wstring saveId, const std::map<LabelIdType, LabelType>& labelMapping);
};
}}}

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -1,885 +0,0 @@
// TODO: This is a dup, we should get back to the shared one. But this one has some stuff the other doesn't.
//
// <copyright file="basetypes.h" company="Microsoft">
// Copyright (c) Microsoft Corporation. All rights reserved.
// </copyright>
//
#pragma once
#ifndef _BASETYPES_
#define _BASETYPES_
// [kit]: seems SECURE_SCL=0 doesn't work - causes crashes in release mode
// there are some complaints along this line on the web
// so disabled for now
//
//// we have agreed that _SECURE_SCL is disabled for release builds
//// it would be super dangerous to mix projects where this is inconsistent
//// this is one way to detect possible mismatches
//#ifdef NDEBUG
//#if !defined(_CHECKED) && _SECURE_SCL != 0
//#error "_SECURE_SCL should be disabled for release builds"
//#endif
//#endif
#ifndef UNDER_CE // fixed-buffer overloads not available for wince
#ifdef _CRT_SECURE_CPP_OVERLOAD_STANDARD_NAMES // fixed-buffer overloads for strcpy() etc.
#undef _CRT_SECURE_CPP_OVERLOAD_STANDARD_NAMES
#endif
#define _CRT_SECURE_CPP_OVERLOAD_STANDARD_NAMES 1
#endif
#pragma warning (push)
#pragma warning (disable: 4793) // caused by varargs
// disable certain parts of basetypes for wince compilation
#ifdef UNDER_CE
#define BASETYPES_NO_UNSAFECRTOVERLOAD // disable unsafe CRT overloads (safe functions don't exist in wince)
#define BASETYPES_NO_STRPRINTF // dependent functions here are not defined for wince
#endif
#ifndef OACR // dummies when we are not compiling under Office
#define OACR_WARNING_SUPPRESS(x, y)
#define OACR_WARNING_DISABLE(x, y)
#define OACR_WARNING_PUSH
#define OACR_WARNING_POP
#endif
#ifndef OACR_ASSUME // this seems to be a different one
#define OACR_ASSUME(x)
#endif
// following oacr warnings are not level1 or level2-security
// in currect stage we want to ignore those warnings
// if necessay this can be fixed at later stage
// not a bug
OACR_WARNING_DISABLE(EXC_NOT_CAUGHT_BY_REFERENCE, "Not indicating a bug or security threat.");
OACR_WARNING_DISABLE(LOCALDECLHIDESLOCAL, "Not indicating a bug or security threat.");
// not reviewed
OACR_WARNING_DISABLE(MISSING_OVERRIDE, "Not level1 or level2_security.");
OACR_WARNING_DISABLE(EMPTY_DTOR, "Not level1 or level2_security.");
OACR_WARNING_DISABLE(DEREF_NULL_PTR, "Not level1 or level2_security.");
OACR_WARNING_DISABLE(INVALID_PARAM_VALUE_1, "Not level1 or level2_security.");
OACR_WARNING_DISABLE(VIRTUAL_CALL_IN_CTOR, "Not level1 or level2_security.");
OACR_WARNING_DISABLE(POTENTIAL_ARGUMENT_TYPE_MISMATCH, "Not level1 or level2_security.");
// determine WIN32 api calling convention
// it seems this is normally stdcall?? but when compiling as /clr:pure or /clr:Safe
// this is not supported, so in this case, we need to use the 'default' calling convention
// TODO: can we reuse the #define of WINAPI??
#ifdef _WIN32
#ifdef _M_CEE_SAFE
#define WINAPI_CC __clrcall
#elif _M_CEE
#define WINAPI_CC __clrcall
#else
#define WINAPI_CC __stdcall
#endif
#endif
// fix some warnings in STL
#if !defined(_DEBUG) || defined(_CHECKED) || defined(_MANAGED)
#pragma warning(disable : 4702) // unreachable code
#endif
#include <stdarg.h>
#include <stdio.h>
#include <string.h> // include here because we redefine some names later
#include <string>
#include <vector>
#include <cmath> // for HUGE_VAL
#include <assert.h>
#include <map>
#ifdef __windows__
#include <windows.h> // for CRITICAL_SECTION
#include <strsafe.h> // for strbcpy() etc templates
#endif
#if __unix__
#include <strings.h>
#include <chrono>
#include <thread>
#include <unistd.h>
#include <sys/stat.h>
#include <dlfcn.h>
typedef unsigned char byte;
#endif
#pragma push_macro("STRSAFE_NO_DEPRECATE")
#define STRSAFE_NO_DEPRECATE // deprecation managed elsewhere, not by strsafe
#pragma pop_macro("STRSAFE_NO_DEPRECATE")
// CRT error handling seems to not be included in wince headers
// so we define our own imports
#ifdef UNDER_CE
// TODO: is this true - is GetLastError == errno?? - also this adds a dependency on windows.h
#define errno GetLastError()
// strerror(x) - x here is normally errno - TODO: make this return errno as a string
#define strerror(x) "strerror error but can't report error number sorry!"
#endif
#ifndef __in // dummies for sal annotations if compiler does not support it
#define __in
#define __inout_z
#define __in_count(x)
#define __inout_cap(x)
#define __inout_cap_c(x)
#endif
#ifndef __out_z_cap // non-VS2005 annotations
#define __out_cap(x)
#define __out_z_cap(x)
#define __out_cap_c(x)
#endif
#ifndef __override // and some more non-std extensions required by Office
#define __override virtual
#endif
// disable warnings for which fixing would make code less readable
#pragma warning(disable : 4290) // throw() declaration ignored
#pragma warning(disable : 4244) // conversion from typeA to typeB, possible loss of data
// ----------------------------------------------------------------------------
// basic macros
// ----------------------------------------------------------------------------
#define SAFE_DELETE(p) { if(p) { delete (p); (p)=NULL; } }
#define SAFE_RELEASE(p) { if(p) { (p)->Release(); (p)=NULL; } } // nasty! use CComPtr<>
#ifndef ASSERT
#ifdef _CHECKED // basetypes.h expects this function to be defined (it is in message.h)
extern void _CHECKED_ASSERT_error(const char * file, int line, const char * exp);
#define ASSERT(exp) ((exp)||(_CHECKED_ASSERT_error(__FILE__,__LINE__,#exp),0))
#else
#define ASSERT assert
#endif
#endif
using namespace std;
// ----------------------------------------------------------------------------
// basic data types
// ----------------------------------------------------------------------------
namespace msra { namespace basetypes {
// class ARRAY -- std::vector with array-bounds checking
// VS 2008 and above do this, so there is no longer a need for this.
template<class _ElemType>
class ARRAY : public std::vector<_ElemType>
{
#if defined (_DEBUG) || defined (_CHECKED) // debug version with range checking
static void throwOutOfBounds()
{ // (moved to separate function hoping to keep inlined code smaller
OACR_WARNING_PUSH;
OACR_WARNING_DISABLE(IGNOREDBYCOMMA, "Reviewd OK. Special trick below to show a message when assertion fails"
"[rogeryu 2006/03/24]");
OACR_WARNING_DISABLE(BOGUS_EXPRESSION_LIST, "This is intentional. [rogeryu 2006/03/24]");
ASSERT (("ARRAY::operator[] out of bounds", false));
OACR_WARNING_POP;
}
#endif
public:
ARRAY() : std::vector<_ElemType> () { }
ARRAY (int size) : std::vector<_ElemType> (size) { }
#if defined (_DEBUG) || defined (_CHECKED) // debug version with range checking
// ------------------------------------------------------------------------
// operator[]: with array-bounds checking
// ------------------------------------------------------------------------
inline _ElemType & operator[] (int index) // writing
{
if (index < 0 || index >= size()) throwOutOfBounds();
return (*(std::vector<_ElemType>*) this)[index];
}
// ------------------------------------------------------------------------
inline const _ElemType & operator[] (int index) const // reading
{
if (index < 0 || index >= size()) throwOutOfBounds();
return (*(std::vector<_ElemType>*) this)[index];
}
#endif
// ------------------------------------------------------------------------
// size(): same as base class, but returning an 'int' instead of 'size_t'
// to allow for better readable code
// ------------------------------------------------------------------------
inline int size() const
{
size_t siz = ((std::vector<_ElemType>*) this)->size();
return (int) siz;
}
};
// overload swap(), otherwise we'd fallback to 3-way assignment & possibly throw
template<class _T> inline void swap (ARRAY<_T> & L, ARRAY<_T> & R) throw()
{ swap ((std::vector<_T> &) L, (std::vector<_T> &) R); }
// class fixed_vector - non-resizable vector
template<class _T> class fixed_vector
{
_T * p; // pointer array
size_t n; // number of elements
void check (int index) const { index; ASSERT (index >= 0 && (size_t) index < n); }
void check (size_t index) const { index; ASSERT (index < n); }
// ... TODO: when I make this public, LinearTransform.h acts totally up but I cannot see where it comes from.
//fixed_vector (const fixed_vector & other) : n (0), p (NULL) { *this = other; }
public:
fixed_vector() : n (0), p (NULL) { }
void resize (int size) { clear(); if (size > 0) { p = new _T[size]; n = size; } }
void resize (size_t size) { clear(); if (size > 0) { p = new _T[size]; n = size; } }
fixed_vector (int size) : n (size), p (size > 0 ? new _T[size] : NULL) { }
fixed_vector (size_t size) : n ((int) size), p (size > 0 ? new _T[size] : NULL) { }
~fixed_vector() { delete[] p; }
inline int size() const { return (int) n; }
inline int capacity() const { return (int) n; }
inline bool empty() const { return n == 0; }
void clear() { delete[] p; p = NULL; n = 0; }
_T * begin() { return p; }
const _T * begin() const { return p; }
_T * end() { return p + n; } // note: n == 0 so result is NULL
inline _T & operator[] (int index) { check (index); return p[index]; } // writing
inline const _T & operator[] (int index) const { check (index); return p[index]; } // reading
inline _T & operator[] (size_t index) { check (index); return p[index]; } // writing
inline const _T & operator[] (size_t index) const { check (index); return p[index]; } // reading
inline int indexof (const _T & elem) const { ASSERT (&elem >= p && &elem < p + n); return &elem - p; }
inline void swap (fixed_vector & other) throw() { std::swap (other.p, p); std::swap (other.n, n); }
template<class VECTOR> fixed_vector & operator= (const VECTOR & other)
{
int other_n = (int) other.size();
fixed_vector tmp (other_n);
for (int k = 0; k < other_n; k++) tmp[k] = other[k];
swap (tmp);
return *this;
}
fixed_vector & operator= (const fixed_vector & other)
{
int other_n = (int) other.size();
fixed_vector tmp (other_n);
for (int k = 0; k < other_n; k++) tmp[k] = other[k];
swap (tmp);
return *this;
}
template<class VECTOR> fixed_vector (const VECTOR & other) : n (0), p (NULL) { *this = other; }
};
template<class _T> inline void swap (fixed_vector<_T> & L, fixed_vector<_T> & R) throw() { L.swap (R); }
// class matrix - simple fixed-size 2-dimensional array, access elements as m(i,j)
// stored as concatenation of rows
template<class T> class matrix : fixed_vector<T>
{
size_t numcols;
size_t locate (size_t i, size_t j) const { ASSERT (i < rows() && j < cols()); return i * cols() + j; }
public:
typedef T elemtype;
matrix() : numcols (0) {}
matrix (size_t n, size_t m) { resize (n, m); }
void resize (size_t n, size_t m) { numcols = m; fixed_vector<T>::resize (n * m); }
size_t cols() const { return numcols; }
size_t rows() const { return empty() ? 0 : size() / cols(); }
size_t size() const { return fixed_vector<T>::size(); } // use this for reading and writing... not nice!
bool empty() const { return fixed_vector<T>::empty(); }
T & operator() (size_t i, size_t j) { return (*this)[locate(i,j)]; }
const T & operator() (size_t i, size_t j) const { return (*this)[locate(i,j)]; }
void swap (matrix & other) throw() { std::swap (numcols, other.numcols); fixed_vector<T>::swap (other); }
};
template<class _T> inline void swap (matrix<_T> & L, matrix<_T> & R) throw() { L.swap (R); }
// TODO: get rid of these
typedef std::string STRING;
typedef std::wstring WSTRING;
#ifdef __unix__
typedef wchar_t TCHAR;
#endif
typedef std::basic_string<TCHAR> TSTRING; // wide/narrow character string
// derive from this for noncopyable classes (will get you private unimplemented copy constructors)
// ... TODO: change all of basetypes classes/structs to use this
class noncopyable
{
noncopyable & operator= (const noncopyable &);
noncopyable (const noncopyable &);
public:
noncopyable(){}
};
struct throw_hr
{
const char * msg;
inline throw_hr (const char * msg = NULL) : msg (msg) {}
};
// back-mapping of exceptions to HRESULT codes
// usage pattern: HRESULT COM_function (...) { try { exception-based function body } catch_hr_return; }
#define catch_hr_return \
catch (const bad_alloc &) { return E_OUTOFMEMORY; } \
catch (const bad_hr & e) { return e.hr; } \
catch (const invalid_argument &) { return E_INVALIDARG; } \
catch (const runtime_error &) { return E_FAIL; } \
catch (const logic_error &) { return E_UNEXPECTED; } \
catch (const exception &) { return E_FAIL; } \
return S_OK;
};}; // namespace
#ifndef BASETYPES_NO_UNSAFECRTOVERLOAD // if on, no unsafe CRT overload functions
// ----------------------------------------------------------------------------
// overloads for "unsafe" CRT functions used in our code base
// ----------------------------------------------------------------------------
// strlen/wcslen overloads for fixed-buffer size
// Note: Careful while fixing bug related to these templates.
// In all attempted experiments, in seems all 6 definitions are required
// below to get the correct behaviour. Be very very careful
// not to delete something without testing that case 5&6 have "size" deduced.
// 1. char *
// 2. char * const
// 3. const char *
// 4. const char * const
// 5. char (&) [size]
// 6. const char (&) [size]
// the following includes all headers that use strlen() and fail because of the mapping below
// to find those, change #define strlen strlen_ to something invalid e.g. strlen::strlen_
#if _MSC_VER >= 1600 // VS 2010 --TODO: fix this by correct include order instead
#include <intrin.h> // defines strlen() as an intrinsic in VS 2010
#include <typeinfo> // uses strlen()
#include <xlocale> // uses strlen()
#endif
#define strlen strlen_
template<typename _T>
size_t strlen_(_T &s) { return strnlen_s(static_cast<const char *>(s), SIZE_MAX); } // never be called but needed to keep compiler happy
template<typename _T> inline size_t strlen_(const _T &s) { return strnlen(static_cast<const char *>(s), SIZE_MAX); }
template<> inline size_t strlen_(char * &s) { return strnlen(s, SIZE_MAX); }
template<> inline size_t strlen_(const char * &s) { return strnlen(s, SIZE_MAX); }
template<size_t n> inline size_t strlen_(const char (&s)[n]) { return (strnlen(s, n)); }
template<size_t n> inline size_t strlen_(char (&s)[n]) { return (strnlen(s, n)); }
#define wcslen wcslen_
template<typename _T>
size_t wcslen_(_T &s) { return wcsnlen_s(static_cast<const wchar_t *>(s), SIZE_MAX); } // never be called but needed to keep compiler happy
template<> inline size_t wcslen_(wchar_t * &s) { return wcsnlen(s, SIZE_MAX); }
template<> inline size_t wcslen_(const wchar_t * &s) { return wcsnlen(s, SIZE_MAX); }
template<size_t n> inline size_t wcslen_(const wchar_t (&s)[n]) { return (wcsnlen(s, n)); }
template<size_t n> inline size_t wcslen_(wchar_t (&s)[n]) { return (wcsnlen(s, n)); }
// xscanf wrappers -- one overload for each actual use case in our code base
static inline int sscanf (const char * buf, const char * format, int * i1) { return sscanf (buf, format, i1); }
static inline int sscanf (const char * buf, const char * format, int * i1, int * i2) { return sscanf (buf, format, i1, i2); }
static inline int sscanf (const char * buf, const char * format, int * i1, int * i2, int * i3) { return sscanf (buf, format, i1, i2, i3); }
static inline int sscanf (const char * buf, const char * format, double * f1) { return sscanf (buf, format, f1); }
static inline int swscanf (const wchar_t * buf, const wchar_t * format, int * i1) { return swscanf (buf, format, i1); }
static inline int fscanf (FILE * file, const char * format, float * f1) { return fscanf (file, format, f1); }
// cacpy -- fixed-size character array (same as original strncpy (dst, src, sizeof (dst)))
// NOTE: THIS FUNCTION HAS NEVER BEEN TESTED. REMOVE THIS COMMENT ONCE IT HAS.
template<class T, size_t n> static inline void cacpy (T (&dst)[n], const T * src)
{ for (int i = 0; i < n; i++) { dst[i] = *src; if (*src) src++; } }
// { return strncpy (dst, src, n); } // using original C std lib function
#endif
// ----------------------------------------------------------------------------
// frequently missing string functions
// ----------------------------------------------------------------------------
namespace msra { namespace strfun {
#ifndef BASETYPES_NO_STRPRINTF
template<typename C> struct basic_cstring : public std::basic_string<C>
{
template<typename S> basic_cstring (S p) : std::basic_string<C> (p) { }
operator const C * () const { return this->c_str(); }
};
typedef basic_cstring<char> cstring;
typedef basic_cstring<wchar_t> wcstring;
// [w]strprintf() -- like sprintf() but resulting in a C++ string
template<class _T> struct _strprintf : public std::basic_string<_T>
{ // works for both wchar_t* and char*
_strprintf (const _T * format, ...)
{
va_list args; va_start (args, format); // varargs stuff
size_t n = _cprintf (format, args); // num chars excl. '\0'
const int FIXBUF_SIZE = 128; // incl. '\0'
if (n < FIXBUF_SIZE)
{
_T fixbuf[FIXBUF_SIZE];
this->assign (_sprintf (&fixbuf[0], sizeof (fixbuf)/sizeof (*fixbuf), format, args), n);
}
else // too long: use dynamically allocated variable-size buffer
{
std::vector<_T> varbuf (n + 1); // incl. '\0'
this->assign (_sprintf (&varbuf[0], varbuf.size(), format, args), n);
}
}
private:
// helpers
inline size_t _cprintf (const wchar_t * format, va_list args) { return _vscwprintf (format, args); }
inline size_t _cprintf (const char * format, va_list args) { return _vscprintf (format, args); }
inline const wchar_t * _sprintf (wchar_t * buf, size_t bufsiz, const wchar_t * format, va_list args) { vswprintf_s (buf, bufsiz, format, args); return buf; }
inline const char * _sprintf ( char * buf, size_t bufsiz, const char * format, va_list args) { vsprintf_s (buf, bufsiz, format, args); return buf; }
};
typedef strfun::_strprintf<char> strprintf; // char version
typedef strfun::_strprintf<wchar_t> wstrprintf; // wchar_t version
#endif
//http://www.nanobit.net/putty/doxy/PUTTY_8H-source.html
#ifndef CP_UTF8
#define CP_UTF8 65001
#endif
// string-encoding conversion functions
#ifdef _WIN32
struct utf8 : std::string { utf8 (const std::wstring & p) // utf-16 to -8
{
size_t len = p.length();
if (len == 0) { return;} // empty string
msra::basetypes::fixed_vector<char> buf (3 * len + 1); // max: 1 wchar => up to 3 mb chars
// ... TODO: this fill() should be unnecessary (a 0 is appended)--but verify
std::fill (buf.begin (), buf.end (), 0);
int rc = WideCharToMultiByte (CP_UTF8, 0, p.c_str(), (int) len,
&buf[0], (int) buf.size(), NULL, NULL);
if (rc == 0) throw std::runtime_error ("WideCharToMultiByte");
(*(std::string*)this) = &buf[0];
}};
struct utf16 : std::wstring { utf16 (const std::string & p) // utf-8 to -16
{
size_t len = p.length();
if (len == 0) { return;} // empty string
msra::basetypes::fixed_vector<wchar_t> buf (len + 1);
// ... TODO: this fill() should be unnecessary (a 0 is appended)--but verify
std::fill (buf.begin (), buf.end (), (wchar_t) 0);
int rc = MultiByteToWideChar (CP_UTF8, 0, p.c_str(), (int) len,
&buf[0], (int) buf.size());
if (rc == 0) throw std::runtime_error ("MultiByteToWideChar");
ASSERT (rc < buf.size ());
(*(std::wstring*)this) = &buf[0];
}};
#endif
#pragma warning(push)
#pragma warning(disable : 4996) // Reviewed by Yusheng Li, March 14, 2006. depr. fn (wcstombs, mbstowcs)
static inline std::string wcstombs (const std::wstring & p) // output: MBCS
{
size_t len = p.length();
msra::basetypes::fixed_vector<char> buf (2 * len + 1); // max: 1 wchar => 2 mb chars
std::fill (buf.begin (), buf.end (), 0);
::wcstombs (&buf[0], p.c_str(), 2 * len + 1);
return std::string (&buf[0]);
}
static inline std::wstring mbstowcs (const std::string & p) // input: MBCS
{
size_t len = p.length();
msra::basetypes::fixed_vector<wchar_t> buf (len + 1); // max: >1 mb chars => 1 wchar
std::fill (buf.begin (), buf.end (), (wchar_t) 0);
OACR_WARNING_SUPPRESS(UNSAFE_STRING_FUNCTION, "Reviewed OK. size checked. [rogeryu 2006/03/21]");
::mbstowcs (&buf[0], p.c_str(), len + 1);
return std::wstring (&buf[0]);
}
#pragma warning(pop)
static inline std::string utf8 (const std::wstring & p) { return msra::strfun::wcstombs (p.c_str()); } // output: UTF-8... not really
static inline std::wstring utf16 (const std::string & p) { return msra::strfun::mbstowcs(p.c_str()); } // input: UTF-8... not really
// split and join -- tokenize a string like strtok() would, join() strings together
template<class _T> static inline std::vector<std::basic_string<_T>> split (const std::basic_string<_T> & s, const _T * delim)
{
std::vector<std::basic_string<_T>> res;
for (size_t st = s.find_first_not_of (delim); st != std::basic_string<_T>::npos; )
{
size_t en = s.find_first_of (delim, st +1);
if (en == std::basic_string<_T>::npos) en = s.length();
res.push_back (s.substr (st, en-st));
st = s.find_first_not_of (delim, en +1); // may exceed
}
return res;
}
template<class _T> static inline std::basic_string<_T> join (const std::vector<std::basic_string<_T>> & a, const _T * delim)
{
std::basic_string<_T> res;
for (int i = 0; i < (int) a.size(); i++)
{
if (i > 0) res.append (delim);
res.append (a[i]);
}
return res;
}
#ifdef _WIN32
// parsing strings to numbers
static inline int toint (const wchar_t * s)
{
return _wtoi (s); // ... TODO: check it
}
#endif
static inline int toint (const char * s)
{
return atoi (s); // ... TODO: check it
}
static inline int toint (const std::wstring & s) { return toint (s.c_str()); }
static inline double todouble (const char * s)
{
char * ep; // will be set to point to first character that failed parsing
double value = strtod (s, &ep);
if (*s == 0 || *ep != 0)
throw std::runtime_error ("todouble: invalid input string");
return value;
}
// TODO: merge this with todouble(const char*) above
static inline double todouble (const std::string & s)
{
s.size(); // just used to remove the unreferenced warning
double value = 0.0;
// stod supposedly exists in VS2010, but some folks have compilation errors
// If this causes errors again, change the #if into the respective one for VS 2010.
#if _MSC_VER > 1400 // VS 2010+
size_t * idx = 0;
value = std::stod (s, idx);
if (idx) throw std::runtime_error ("todouble: invalid input string");
#else
char *ep = 0; // will be updated by strtod to point to first character that failed parsing
value = strtod (s.c_str(), &ep);
// strtod documentation says ep points to first unconverted character OR
// return value will be +/- HUGE_VAL for overflow/underflow
if (ep != s.c_str() + s.length() || value == HUGE_VAL || value == -HUGE_VAL)
throw std::runtime_error ("todouble: invalid input string");
#endif
return value;
}
static inline double todouble (const std::wstring & s)
{
wchar_t * endptr;
double value = wcstod (s.c_str(), &endptr);
if (*endptr) throw std::runtime_error ("todouble: invalid input string");
return value;
}
// ----------------------------------------------------------------------------
// tokenizer -- utility for white-space tokenizing strings in a character buffer
// This simple class just breaks a string, but does not own the string buffer.
// ----------------------------------------------------------------------------
class tokenizer : public std::vector<char*>
{
const char * delim;
public:
tokenizer (const char * delim, size_t cap) : delim (delim) { reserve (cap); }
// Usage: tokenizer tokens (delim, capacity); tokens = buf; tokens.size(), tokens[i]
void operator= (char * buf)
{
resize (0);
// strtok_s not available on all platforms - so backoff to strtok on those
#ifdef strtok_s
char * context; // for strtok_s()
for (char * p = strtok_s (buf, delim, &context); p; p = strtok_s (NULL, delim, &context))
push_back (p);
#else
for (char * p = strtok (buf, delim); p; p = strtok (NULL, delim))
push_back (p);
#endif
}
};
};}; // namespace
static inline msra::strfun::cstring charpath (const std::wstring & p)
{
#ifdef _WIN32
return std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>>().to_bytes(p);
#else // old version, delete once we know it works
size_t len = p.length();
std::vector<char> buf(2 * len + 1, 0); // max: 1 wchar => 2 mb chars
::wcstombs(buf.data(), p.c_str(), 2 * len + 1);
return msra::strfun::cstring (&buf[0]);
#endif
}
static inline FILE* _wfopen (const wchar_t * path, const wchar_t * mode) { return fopen(charpath(path), charpath(mode)); }
static inline void Sleep (size_t ms) { std::this_thread::sleep_for (std::chrono::milliseconds (ms)); }
// ----------------------------------------------------------------------------
// wrappers for some basic types (files, handles, timer)
// ----------------------------------------------------------------------------
namespace msra { namespace basetypes {
// FILE* with auto-close; use auto_file_ptr instead of FILE*.
// Warning: do not pass an auto_file_ptr to a function that calls fclose(),
// except for fclose() itself.
class auto_file_ptr
{
FILE * f;
FILE * operator= (auto_file_ptr &); // can't ref-count: no assignment
auto_file_ptr (auto_file_ptr &);
// implicit close (destructor, assignment): we ignore error
void close() throw() { if (f) try { if (f != stdin && f != stdout && f != stderr) ::fclose (f); } catch (...) { } f = NULL; }
void openfailed (const std::string & path) { throw std::runtime_error ("auto_file_ptr: error opening file '" + path + "': " + strerror (errno)); }
protected:
friend int fclose (auto_file_ptr&); // explicit close (note: may fail)
int fclose() { int rc = ::fclose (f); if (rc == 0) f = NULL; return rc; }
public:
auto_file_ptr() : f (NULL) { }
~auto_file_ptr() { close(); }
auto_file_ptr (const char * path, const char * mode) { f = fopen (path, mode); if (f == NULL) openfailed (path); }
auto_file_ptr (const wchar_t * wpath, const char * mode) { f = _wfopen (wpath, msra::strfun::utf16 (mode).c_str()); if (f == NULL) openfailed (msra::strfun::utf8 (wpath)); }
FILE * operator= (FILE * other) { close(); f = other; return f; }
auto_file_ptr (FILE * other) : f (other) { }
operator FILE * () const { return f; }
FILE * operator->() const { return f; }
void swap (auto_file_ptr & other) throw() { std::swap (f, other.f); }
};
inline int fclose (auto_file_ptr & af) { return af.fclose(); }
};};
namespace msra { namespace files {
// ----------------------------------------------------------------------------
// textreader -- simple reader for text files --we need this all the time!
// Currently reads 8-bit files, but can return as wstring, in which case
// they are interpreted as UTF-8 (without BOM).
// Note: Not suitable for pipes or typed input due to readahead (fixable if needed).
// ----------------------------------------------------------------------------
class textreader
{
msra::basetypes::auto_file_ptr f;
std::vector<char> buf; // read buffer (will only grow, never shrink)
int ch; // next character (we need to read ahead by one...)
char getch() { char prevch = (char) ch; ch = fgetc (f); return prevch; }
public:
textreader (const std::wstring & path) : f (path.c_str(), "rb") { buf.reserve (10000); ch = fgetc (f); }
operator bool() const { return ch != EOF; } // true if still a line to read
std::string getline() // get and consume the next line
{
if (ch == EOF) throw std::logic_error ("textreader: attempted to read beyond EOF");
assert (buf.empty());
// get all line's characters --we recognize UNIX (LF), DOS (CRLF), and Mac (CR) convention
while (ch != EOF && ch != '\n' && ch != '\r') buf.push_back (getch());
if (ch != EOF && getch() == '\r' && ch == '\n') getch(); // consume EOLN char
std::string line (buf.begin(), buf.end());
buf.clear();
return line;
}
std::wstring wgetline() { return msra::strfun::utf16 (getline()); }
};
};};
// ----------------------------------------------------------------------------
// functional-programming style helper macros (...do this with templates?)
// ----------------------------------------------------------------------------
#define foreach_index(_i,_dat) for (int _i = 0; _i < (int) (_dat).size(); _i++)
#define map_array(_x,_expr,_y) { _y.resize (_x.size()); foreach_index(_i,_x) _y[_i]=_expr(_x[_i]); }
#define reduce_array(_x,_expr,_y) { foreach_index(_i,_x) _y = (_i==0) ? _x[_i] : _expr(_y,_x[_i]); }
// ----------------------------------------------------------------------------
// frequently missing utility functions
// ----------------------------------------------------------------------------
namespace msra { namespace util {
// to (slightly) simplify processing of command-line arguments.
// command_line args (argc, argv);
// while (args.has (1) && args[0][0] == '-') { option = args.shift(); process (option); }
// for (const wchar_t * arg = args.shift(); arg; arg = args.shift()) { process (arg); }
class command_line
{
int num;
const wchar_t * * args;
public:
command_line (int argc, wchar_t * argv[]) : num (argc), args ((const wchar_t **) argv) { shift(); }
inline int size() const { return num; }
inline bool has (int left) { return size() >= left; }
const wchar_t * shift() { if (size() == 0) return NULL; num--; return *args++; }
const wchar_t * operator[] (int i) const { return (i < 0 || i >= size()) ? NULL : args[i]; }
};
// byte-reverse a variable --reverse all bytes (intended for integral types and float)
template<typename T> static inline void bytereverse (T & v) throw()
{ // note: this is more efficient than it looks because sizeof (v[0]) is a constant
char * p = (char *) &v;
const size_t elemsize = sizeof (v);
for (int k = 0; k < elemsize / 2; k++) // swap individual bytes
swap (p[k], p[elemsize-1 - k]);
}
// byte-swap an entire array
template<class V> static inline void byteswap (V & v) throw()
{
foreach_index (i, v)
bytereverse (v[i]);
}
// execute a block with retry
// Block must be restartable.
// Use this when writing small files to those unreliable Windows servers.
// TODO: This will fail to compile under VS 2008--we need an #ifdef around this
template<typename FUNCTION> static void attempt (int retries, const FUNCTION & body)
{
for (int attempt = 1; ; attempt++)
{
try
{
body();
if (attempt > 1) fprintf (stderr, "attempt: success after %d retries\n", attempt);
break;
}
catch (const std::exception & e)
{
if (attempt >= retries)
throw; // failed N times --give up and rethrow the error
fprintf (stderr, "attempt: %s, retrying %d-th time out of %d...\n", e.what(), attempt+1, retries);
::Sleep (1000); // wait a little, then try again
}
}
}
};}; // namespace
#ifdef _WIN32
// ----------------------------------------------------------------------------
// frequently missing Win32 functions
// ----------------------------------------------------------------------------
// strerror() for Win32 error codes
static inline std::wstring FormatWin32Error (DWORD error)
{
wchar_t buf[1024] = { 0 };
::FormatMessageW (FORMAT_MESSAGE_FROM_SYSTEM, "", error, 0, buf, sizeof (buf)/sizeof (*buf) -1, NULL);
std::wstring res (buf);
// eliminate newlines (and spaces) from the end
size_t last = res.find_last_not_of (L" \t\r\n");
if (last != std::string::npos) res.erase (last +1, res.length());
return res;
}
// we always wanted this!
#pragma warning (push)
#pragma warning (disable: 6320) // Exception-filter expression is the constant EXCEPTION_EXECUTE_HANDLER
#pragma warning (disable: 6322) // Empty _except block
static inline void SetCurrentThreadName (const char* threadName)
{ // from http://msdn.microsoft.com/en-us/library/xcb2z8hs.aspx
::Sleep(10);
#pragma pack(push,8)
struct { DWORD dwType; LPCSTR szName; DWORD dwThreadID; DWORD dwFlags; } info = { 0x1000, threadName, (DWORD) -1, 0 };
#pragma pack(pop)
__try { RaiseException (0x406D1388, 0, sizeof(info)/sizeof(ULONG_PTR), (ULONG_PTR*)&info); }
__except(EXCEPTION_EXECUTE_HANDLER) { }
}
#pragma warning (pop)
// return a string as a CoTaskMemAlloc'ed memory object
// Returns NULL if out of memory (we don't throw because we'd just catch it outside and convert to HRESULT anyway).
static inline LPWSTR CoTaskMemString (const wchar_t * s)
{
size_t n = wcslen (s) + 1; // number of chars to allocate and copy
LPWSTR p = (LPWSTR) ::CoTaskMemAlloc (sizeof (*p) * n);
if (p) for (size_t i = 0; i < n; i++) p[i] = s[i];
return p;
}
template<class S> static inline void ZeroStruct (S & s) { memset (&s, 0, sizeof (s)); }
#endif
// ----------------------------------------------------------------------------
// machine dependent
// ----------------------------------------------------------------------------
#define MACHINE_IS_BIG_ENDIAN (false)
using namespace msra::basetypes; // for compatibility
#pragma warning (pop)
// RuntimeError - throw a std::runtime_error with a formatted error string
#ifdef _MSC_VER
__declspec(noreturn)
#endif
static inline void RuntimeError(const char * format, ...)
{
va_list args;
char buffer[1024];
va_start(args, format);
vsprintf(buffer, format, args);
throw std::runtime_error(buffer);
};
// LogicError - throw a std::logic_error with a formatted error string
#ifdef _MSC_VER
__declspec(noreturn)
#endif
static inline void LogicError(const char * format, ...)
{
va_list args;
char buffer[1024];
va_start(args, format);
vsprintf(buffer, format, args);
throw std::logic_error(buffer);
};
// ----------------------------------------------------------------------------
// dynamic loading of modules
// ----------------------------------------------------------------------------
#ifdef _WIN32
class Plugin
{
HMODULE m_hModule; // module handle for the writer DLL
std::wstring m_dllName; // name of the writer DLL
public:
Plugin() { m_hModule = NULL; }
template<class STRING> // accepts char (UTF-8) and wide string
FARPROC Load(const STRING & plugin, const std::string & proc)
{
m_dllName = msra::strfun::utf16(plugin);
m_dllName += L".dll";
m_hModule = LoadLibrary(m_dllName.c_str());
if (m_hModule == NULL)
RuntimeError("Plugin not found: %s", msra::strfun::utf8(m_dllName));
// create a variable of each type just to call the proper templated version
return GetProcAddress(m_hModule, proc.c_str());
}
~Plugin() { if (m_hModule) FreeLibrary(m_hModule); }
};
#else
class Plugin
{
public:
template<class STRING> // accepts char (UTF-8) and wide string
void * Load(const STRING & plugin, const std::string & proc)
{
RuntimeError("Plugins not implemented on Linux yet");
return nullptr;
}
};
#endif
#endif // _BASETYPES_

Просмотреть файл

@ -1,122 +0,0 @@
//
// <copyright file="biggrowablevectors.h" company="Microsoft">
// Copyright (c) Microsoft Corporation. All rights reserved.
// </copyright>
//
// biggrowablevectors.h -- big growable vector that uses two layers and optionally a disk backing store for paging
#pragma once
namespace msra { namespace dbn {
// ---------------------------------------------------------------------------
// growablevectorbase -- helper for two-layer growable random-access array
// This allows both a fully allocated vector (with push_back()), e.g. for uids,
// as well as a partially allocated one (content managed by derived class), for features and lattice blocks.
// TODO:
// - test this (make copy of binary first before full compilation; or rebuild the previous version)
// - fully move in-mem range here, test again
// - then we can move towards paging from archive directly (biggrowablevectorarray gets tossed)
// ---------------------------------------------------------------------------
template<class BLOCKTYPE> class growablevectorbase
{
protected: // fix this later
const size_t elementsperblock;
size_t n; // number of elements
std::vector<std::unique_ptr<BLOCKTYPE>> blocks; // the data blocks
void operator= (const growablevectorbase &); // (non-assignable)
void check (size_t t) const { if (t >= n) throw std::logic_error ("growablevectorbase: out of bounds"); } // bounds check helper
// resize intermediate level, but do not allocate blocks
// (may deallocate if shrinking)
void resize_without_commit (size_t T)
{
blocks.resize ((T + elementsperblock-1) / elementsperblock);
n = T;
// TODO: update allocated range
}
// commit memory
// begin/end must be block boundaries
void commit (size_t begin, size_t end, BLOCKTYPE * blockdata)
{
auto blockptr = getblock (begin, end); // memory leak: if this fails (logic error; should never happen)
blockptr.set (blockdata); // take ownership of the block
// TODO: update allocated range --also enforce consecutiveness
}
// flush a block
// begin/end must be block boundaries
void flush (size_t begin, size_t end)
{
auto blockptr = getblock (begin, end); // memory leak: if this fails (logic error; should never happen)
blockptr.reset(); // release it
// TODO: update allocated range --also enforce consecutiveness
}
// helper to get a block pointer, with block referenced as its entire range
std::unique_ptr<BLOCKTYPE> & getblockptr (size_t t) // const
{
check (t);
return blocks[t / elementsperblock];
}
// helper to get a block pointer, with block referenced as its entire range
std::unique_ptr<BLOCKTYPE> & getblockptr (size_t begin, size_t end) const
{
// BUGBUG: last block may be shorter than elementsperblock
if (end - begin != elementsperblock || getblockt (begin) != 0)
throw std::logic_error ("growablevectorbase: non-block boundaries passed to block-level function");
return getblockptr (begin);
}
public:
growablevectorbase (size_t elementsperblock) : elementsperblock (elementsperblock), n (0) { blocks.reserve (1000); }
size_t size() const { return n; } // number of frames
bool empty() const { return size() == 0; }
// to access an element t -> getblock(t)[getblockt(t)]
BLOCKTYPE & getblock (size_t t) const
{
check (t);
const size_t blockid = t / elementsperblock;
return *blocks[blockid].get();
}
size_t getblockt (size_t t) const
{
check (t);
return t % elementsperblock;
}
};
// ---------------------------------------------------------------------------
// biggrowablevector -- big vector we can push_back to
// ---------------------------------------------------------------------------
template<typename ELEMTYPE> class biggrowablevector : public growablevectorbase<std::vector<ELEMTYPE>>
{
public:
biggrowablevector() : growablevectorbase<std::vector<ELEMTYPE>>::growablevectorbase (65536) { }
template<typename VALTYPE> void push_back (VALTYPE e) // VALTYPE could be an rvalue reference
{
size_t i = this->size();
this->resize_without_commit (i + 1);
auto & block = this->getblockptr (i);
if (block.get() == NULL)
block.reset (new std::vector<ELEMTYPE> (this->elementsperblock));
(*block)[this->getblockt (i)] = e;
}
ELEMTYPE & operator[] (size_t t) { return this->getblock(t)[this->getblockt (t)]; } // get an element
const ELEMTYPE & operator[] (size_t t) const { return this->getblock(t)[this->getblockt (t)]; } // get an element
void resize (const size_t n)
{
this->resize_without_commit (n);
foreach_index (i, this->blocks)
if (this->blocks[i].get() == NULL)
this->blocks[i].reset (new std::vector<ELEMTYPE> (this->elementsperblock));
}
};
};};

Просмотреть файл

@ -1,373 +0,0 @@
//
// <copyright file="chunkevalsource.h" company="Microsoft">
// Copyright (c) Microsoft Corporation. All rights reserved.
// </copyright>
//
#pragma once
//#include <objbase.h>
#include "basetypes.h" // for attempt()
#include "htkfeatio.h" // for reading HTK features
#include "minibatchsourcehelpers.h"
#ifndef __unix__
#include "ssematrix.h"
#endif
#ifdef LEAKDETECT
#include <vld.h> // for memory leak detection
#endif
namespace msra { namespace dbn {
class chunkevalsource // : public numamodelmanager
{
const size_t chunksize; // actual block size to perform computation on
// data FIFO
msra::dbn::matrix feat;
std::vector<std::vector<float>> frames; // [t] all feature frames concatenated into a big block
std::vector<char> boundaryflags; // [t] -1 for first and +1 last frame, 0 else (for augmentneighbors())
std::vector<size_t> numframes; // [k] number of frames for all appended files
std::vector<std::wstring> outpaths; // [k] and their pathnames
std::vector<unsigned int> sampperiods; // [k] and sample periods (they should really all be the same...)
size_t vdim; // input dimension
size_t udim; // output dimension
bool minibatchready;
void operator=(const chunkevalsource &);
private:
void clear() // empty the FIFO
{
frames.clear();
boundaryflags.clear();
numframes.clear();
outpaths.clear();
sampperiods.clear();
minibatchready=false;
}
void saveandflush(msra::dbn::matrix &pred)
{
const size_t framesinblock = frames.size();
// write out all files
size_t firstframe = 0;
foreach_index (k, numframes)
{
const wstring & outfile = outpaths[k];
unsigned int sampperiod = sampperiods[k];
size_t n = numframes[k];
msra::files::make_intermediate_dirs (outfile);
fprintf (stderr, "saveandflush: writing %zu frames to %S\n", n, outfile.c_str());
msra::dbn::matrixstripe thispred (pred, firstframe, n);
// some sanity check for the data we've written
const size_t nansinf = thispred.countnaninf();
if (nansinf > 0)
fprintf (stderr, "chunkeval: %d NaNs or INF detected in '%S' (%d frames)\n", (int) nansinf, outfile.c_str(), (int) thispred.cols());
// save it
msra::util::attempt (5, [&]()
{
msra::asr::htkfeatwriter::write (outfile, "USER", sampperiod, thispred);
});
firstframe += n;
}
assert (firstframe == framesinblock); framesinblock;
// and we are done --forget the FIFO content & get ready for next chunk
clear();
}
public:
chunkevalsource (size_t numinput, size_t numoutput, size_t chunksize)
:vdim(numinput),udim(numoutput),chunksize(chunksize)
{
frames.reserve (chunksize * 2);
feat.resize(vdim,chunksize); // initialize to size chunksize
}
// append data to chunk
template<class MATRIX> void addfile (const MATRIX & feat, const string & featkind, unsigned int sampperiod, const std::wstring & outpath)
{
// append to frames; also expand neighbor frames
if (feat.cols() < 2)
throw std::runtime_error ("evaltofile: utterances < 2 frames not supported");
foreach_column (t, feat)
{
std::vector<float> v (&feat(0,t), &feat(0,t) + feat.rows());
frames.push_back (v);
boundaryflags.push_back ((t == 0) ? -1 : (t == feat.cols() -1) ? +1 : 0);
}
numframes.push_back (feat.cols());
outpaths.push_back (outpath);
sampperiods.push_back (sampperiod);
}
void createevalminibatch()
{
const size_t framesinblock = frames.size();
feat.resize(vdim, framesinblock); // input features for whole utt (col vectors)
// augment the features
msra::dbn::augmentneighbors (frames, boundaryflags, 0, framesinblock, feat);
minibatchready=true;
}
void writetofiles(msra::dbn::matrix &pred){ saveandflush(pred); }
msra::dbn::matrix chunkofframes() { assert(minibatchready); return feat; }
bool isminibatchready() { return minibatchready; }
size_t currentchunksize() { return frames.size(); }
void flushinput(){createevalminibatch();}
void reset() { clear(); }
};
class chunkevalsourcemulti // : public numamodelmanager
{
const size_t chunksize; // actual block size to perform computation on
// data FIFO
std::vector<msra::dbn::matrix> feat;
std::vector<std::vector<std::vector<float>>> framesmulti; // [t] all feature frames concatenated into a big block
std::vector<char> boundaryflags; // [t] -1 for first and +1 last frame, 0 else (for augmentneighbors())
std::vector<size_t> numframes; // [k] number of frames for all appended files
std::vector<std::vector<std::wstring>> outpaths; // [k] and their pathnames
std::vector<std::vector<unsigned int>> sampperiods; // [k] and sample periods (they should really all be the same...)
std::vector<size_t> vdims; // input dimension
std::vector<size_t> udims; // output dimension
bool minibatchready;
void operator=(const chunkevalsourcemulti &);
private:
void clear() // empty the FIFO
{
foreach_index(i, vdims)
{
framesmulti[i].clear();
outpaths[i].clear();
sampperiods[i].clear();
}
boundaryflags.clear();
numframes.clear();
minibatchready=false;
}
void saveandflush(msra::dbn::matrix &pred, size_t index)
{
const size_t framesinblock = framesmulti[index].size();
// write out all files
size_t firstframe = 0;
foreach_index (k, numframes)
{
const wstring & outfile = outpaths[index][k];
unsigned int sampperiod = sampperiods[index][k];
size_t n = numframes[k];
msra::files::make_intermediate_dirs (outfile);
fprintf (stderr, "saveandflush: writing %zu frames to %S\n", n, outfile.c_str());
msra::dbn::matrixstripe thispred (pred, firstframe, n);
// some sanity check for the data we've written
const size_t nansinf = thispred.countnaninf();
if (nansinf > 0)
fprintf (stderr, "chunkeval: %d NaNs or INF detected in '%S' (%d frames)\n", (int) nansinf, outfile.c_str(), (int) thispred.cols());
// save it
msra::util::attempt (5, [&]()
{
msra::asr::htkfeatwriter::write (outfile, "USER", sampperiod, thispred);
});
firstframe += n;
}
assert (firstframe == framesinblock); framesinblock;
// and we are done --forget the FIFO content & get ready for next chunk
}
public:
chunkevalsourcemulti (std::vector<size_t> vdims, std::vector<size_t> udims, size_t chunksize)
:vdims(vdims),udims(udims),chunksize(chunksize)
{
foreach_index(i, vdims)
{
msra::dbn::matrix thisfeat;
std::vector<std::vector<float>> frames; // [t] all feature frames concatenated into a big block
frames.reserve(chunksize * 2);
framesmulti.push_back(frames);
//framesmulti[i].reserve (chunksize * 2);
thisfeat.resize(vdims[i], chunksize);
feat.push_back(thisfeat);
outpaths.push_back(std::vector<std::wstring>());
sampperiods.push_back(std::vector<unsigned int>());
//feat[i].resize(vdims[i],chunksize); // initialize to size chunksize
}
}
// append data to chunk
template<class MATRIX> void addfile (const MATRIX & feat, const string & featkind, unsigned int sampperiod, const std::wstring & outpath, size_t index)
{
// append to frames; also expand neighbor frames
if (feat.cols() < 2)
throw std::runtime_error ("evaltofile: utterances < 2 frames not supported");
foreach_column (t, feat)
{
std::vector<float> v (&feat(0,t), &feat(0,t) + feat.rows());
framesmulti[index].push_back (v);
if (index==0)
boundaryflags.push_back ((t == 0) ? -1 : (t == feat.cols() -1) ? +1 : 0);
}
if (index==0)
numframes.push_back (feat.cols());
outpaths[index].push_back (outpath);
sampperiods[index].push_back (sampperiod);
}
void createevalminibatch()
{
foreach_index(i, framesmulti)
{
const size_t framesinblock = framesmulti[i].size();
feat[i].resize(vdims[i], framesinblock); // input features for whole utt (col vectors)
// augment the features
msra::dbn::augmentneighbors (framesmulti[i], boundaryflags, 0, framesinblock, feat[i]);
}
minibatchready=true;
}
void writetofiles(msra::dbn::matrix &pred, size_t index){ saveandflush(pred, index); }
msra::dbn::matrix chunkofframes(size_t index) { assert(minibatchready); assert(index<=feat.size()); return feat[index]; }
bool isminibatchready() { return minibatchready; }
size_t currentchunksize() { return framesmulti[0].size(); }
void flushinput(){createevalminibatch();}
void reset() { clear(); }
};
class FileEvalSource // : public numamodelmanager
{
const size_t chunksize; // actual block size to perform computation on
// data FIFO
std::vector<msra::dbn::matrix> feat;
std::vector<std::vector<std::vector<float>>> framesMulti; // [t] all feature frames concatenated into a big block
std::vector<char> boundaryFlags; // [t] -1 for first and +1 last frame, 0 else (for augmentneighbors())
std::vector<size_t> numFrames; // [k] number of frames for all appended files
std::vector<std::vector<unsigned int>> sampPeriods; // [k] and sample periods (they should really all be the same...)
std::vector<size_t> vdims; // input dimension
std::vector<size_t> leftcontext;
std::vector<size_t> rightcontext;
bool minibatchReady;
size_t minibatchSize;
size_t frameIndex;
void operator=(const FileEvalSource &);
private:
void Clear() // empty the FIFO
{
foreach_index(i, vdims)
{
framesMulti[i].clear();
sampPeriods[i].clear();
}
boundaryFlags.clear();
numFrames.clear();
minibatchReady=false;
frameIndex=0;
}
public:
FileEvalSource(std::vector<size_t> vdims, std::vector<size_t> leftcontext, std::vector<size_t> rightcontext, size_t chunksize) :vdims(vdims), leftcontext(leftcontext), rightcontext(rightcontext), chunksize(chunksize)
{
foreach_index(i, vdims)
{
msra::dbn::matrix thisfeat;
std::vector<std::vector<float>> frames; // [t] all feature frames concatenated into a big block
frames.reserve(chunksize * 2);
framesMulti.push_back(frames);
//framesmulti[i].reserve (chunksize * 2);
thisfeat.resize(vdims[i], chunksize);
feat.push_back(thisfeat);
sampPeriods.push_back(std::vector<unsigned int>());
//feat[i].resize(vdims[i],chunksize); // initialize to size chunksize
}
}
// append data to chunk
template<class MATRIX> void AddFile (const MATRIX & feat, const string & /*featkind*/, unsigned int sampPeriod, size_t index)
{
// append to frames; also expand neighbor frames
if (feat.cols() < 2)
throw std::runtime_error ("evaltofile: utterances < 2 frames not supported");
foreach_column (t, feat)
{
std::vector<float> v (&feat(0,t), &feat(0,t) + feat.rows());
framesMulti[index].push_back (v);
if (index==0)
boundaryFlags.push_back ((t == 0) ? -1 : (t == feat.cols() -1) ? +1 : 0);
}
if (index==0)
numFrames.push_back (feat.cols());
sampPeriods[index].push_back (sampPeriod);
}
void CreateEvalMinibatch()
{
foreach_index(i, framesMulti)
{
const size_t framesInBlock = framesMulti[i].size();
feat[i].resize(vdims[i], framesInBlock); // input features for whole utt (col vectors)
// augment the features
size_t leftextent, rightextent;
// page in the needed range of frames
if (leftcontext[i] == 0 && rightcontext[i] == 0)
{
leftextent = rightextent = augmentationextent(framesMulti[i][0].size(), vdims[i]);
}
else
{
leftextent = leftcontext[i];
rightextent = rightcontext[i];
}
//msra::dbn::augmentneighbors(framesMulti[i], boundaryFlags, 0, leftcontext[i], rightcontext[i],)
msra::dbn::augmentneighbors (framesMulti[i], boundaryFlags, leftextent, rightextent, 0, framesInBlock, feat[i]);
}
minibatchReady=true;
}
void SetMinibatchSize(size_t mbSize){ minibatchSize=mbSize;}
msra::dbn::matrix ChunkOfFrames(size_t index) { assert(minibatchReady); assert(index<=feat.size()); return feat[index]; }
bool IsMinibatchReady() { return minibatchReady; }
size_t CurrentFileSize() { return framesMulti[0].size(); }
void FlushInput(){CreateEvalMinibatch();}
void Reset() { Clear(); }
};
};};

Просмотреть файл

@ -1,24 +0,0 @@
//
// <copyright file="dllmain.cpp" company="Microsoft">
// Copyright (c) Microsoft Corporation. All rights reserved.
// </copyright>
//
// dllmain.cpp : Defines the entry point for the DLL application.
#include "stdafx.h"
BOOL APIENTRY DllMain( HMODULE /*hModule*/,
DWORD ul_reason_for_call,
LPVOID /*lpReserved*/
)
{
switch (ul_reason_for_call)
{
case DLL_PROCESS_ATTACH:
case DLL_THREAD_ATTACH:
case DLL_THREAD_DETACH:
case DLL_PROCESS_DETACH:
break;
}
return TRUE;
}

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -1,620 +0,0 @@
//
// fileutil.h - file I/O with error checking
//
// Copyright (c) Microsoft Corporation. All rights reserved.
//
#pragma once
#ifndef _FILEUTIL_
#define _FILEUTIL_
#include "Platform.h"
#include <stdio.h>
#ifdef __unix__
#include <sys/types.h>
#include <sys/stat.h>
#endif
#include <algorithm> // for std::find
#include <vector>
#include <map>
#include <functional>
#include <cctype>
#include <errno.h>
#include <stdint.h>
#include <assert.h>
#include <string.h> // for strerror()
using namespace std;
#define SAFE_CLOSE(f) (((f) == NULL) || (fcloseOrDie ((f)), (f) = NULL))
// ----------------------------------------------------------------------------
// fopenOrDie(): like fopen() but terminate with err msg in case of error.
// A pathname of "-" returns stdout or stdin, depending on mode, and it will
// change the binary mode if 'b' or 't' are given. If you use this, make sure
// not to fclose() such a handle.
// ----------------------------------------------------------------------------
FILE * fopenOrDie (const string & pathname, const char * mode);
FILE * fopenOrDie (const wstring & pathname, const wchar_t * mode);
#ifndef __unix__
// ----------------------------------------------------------------------------
// fsetmode(): set mode to binary or text
// ----------------------------------------------------------------------------
void fsetmode (FILE * f, char type);
#endif
// ----------------------------------------------------------------------------
// freadOrDie(): like fread() but terminate with err msg in case of error
// ----------------------------------------------------------------------------
void freadOrDie (void * ptr, size_t size, size_t count, FILE * f);
template<class _T>
void freadOrDie (_T & data, int num, FILE * f) // template for vector<>
{ data.resize (num); if (data.size() > 0) freadOrDie (&data[0], sizeof (data[0]), data.size(), f); }
template<class _T>
void freadOrDie (_T & data, size_t num, FILE * f) // template for vector<>
{ data.resize (num); if (data.size() > 0) freadOrDie (&data[0], sizeof (data[0]), data.size(), f); }
// ----------------------------------------------------------------------------
// fwriteOrDie(): like fwrite() but terminate with err msg in case of error
// ----------------------------------------------------------------------------
void fwriteOrDie (const void * ptr, size_t size, size_t count, FILE * f);
template<class _T>
void fwriteOrDie (const _T & data, FILE * f) // template for vector<>
{ if (data.size() > 0) fwriteOrDie (&data[0], sizeof (data[0]), data.size(), f); }
// ----------------------------------------------------------------------------
// fprintfOrDie(): like fprintf() but terminate with err msg in case of error
// ----------------------------------------------------------------------------
void fprintfOrDie (FILE * f, const char *format, ...);
// ----------------------------------------------------------------------------
// fcloseOrDie(): like fclose() but terminate with err msg in case of error
// not yet implemented, but we should
// ----------------------------------------------------------------------------
#define fcloseOrDie fclose
// ----------------------------------------------------------------------------
// fflushOrDie(): like fflush() but terminate with err msg in case of error
// ----------------------------------------------------------------------------
void fflushOrDie (FILE * f);
// ----------------------------------------------------------------------------
// filesize(): determine size of the file in bytes
// ----------------------------------------------------------------------------
size_t filesize (const wchar_t * pathname);
size_t filesize (FILE * f);
int64_t filesize64 (const wchar_t * pathname);
// ----------------------------------------------------------------------------
// fseekOrDie(),ftellOrDie(), fget/setpos(): seek functions with error handling
// ----------------------------------------------------------------------------
// 32-bit offsets only
long fseekOrDie (FILE * f, long offset, int mode = SEEK_SET);
#define ftellOrDie ftell
// ----------------------------------------------------------------------------
// fget/setpos(): seek functions with error handling
// ----------------------------------------------------------------------------
uint64_t fgetpos (FILE * f);
void fsetpos (FILE * f, uint64_t pos);
// ----------------------------------------------------------------------------
// unlinkOrDie(): unlink() with error handling
// ----------------------------------------------------------------------------
void unlinkOrDie (const std::string & pathname);
void unlinkOrDie (const std::wstring & pathname);
// ----------------------------------------------------------------------------
// renameOrDie(): rename() with error handling
// ----------------------------------------------------------------------------
void renameOrDie (const std::string & from, const std::string & to);
void renameOrDie (const std::wstring & from, const std::wstring & to);
// ----------------------------------------------------------------------------
// fexists(): test if a file exists
// ----------------------------------------------------------------------------
bool fexists (const char * pathname);
bool fexists (const wchar_t * pathname);
inline bool fexists (const std::string & pathname) { return fexists (pathname.c_str()); }
inline bool fexists (const std::wstring & pathname) { return fexists (pathname.c_str()); }
// ----------------------------------------------------------------------------
// funicode(): test if a file uses unicode
// ----------------------------------------------------------------------------
bool funicode (FILE * f);
// ----------------------------------------------------------------------------
// fskipspace(): skip space characters
// ----------------------------------------------------------------------------
bool fskipspace (FILE * F);
bool fskipwspace (FILE * F);
// ----------------------------------------------------------------------------
// fgetline(): like fgets() but terminate with err msg in case of error;
// removes the newline character at the end (like gets()), returned buffer is
// always 0-terminated; has second version that returns an STL string instead
// fgetstring(): read a 0-terminated string (terminate if error)
// fgetword(): read a space-terminated token (terminate if error)
// fskipNewLine(): skip all white space until end of line incl. the newline
// ----------------------------------------------------------------------------
// ----------------------------------------------------------------------------
// fputstring(): write a 0-terminated string (terminate if error)
// ----------------------------------------------------------------------------
void fputstring (FILE * f, const char *);
void fputstring (const HANDLE f, const char * str);
void fputstring (FILE * f, const std::string &);
void fputstring (FILE * f, const wchar_t *);
void fputstring (FILE * f, const std::wstring &);
template<class CHAR> CHAR * fgetline (FILE * f, CHAR * buf, int size);
template<class CHAR, size_t n> CHAR * fgetline (FILE * f, CHAR (& buf)[n]) { return fgetline (f, buf, n); }
string fgetline (FILE * f);
wstring fgetlinew (FILE * f);
void fgetline (FILE * f, std::string & s, std::vector<char> & buf);
void fgetline (FILE * f, std::wstring & s, std::vector<char> & buf);
void fgetline (FILE * f, std::vector<char> & buf);
void fgetline (FILE * f, std::vector<wchar_t> & buf);
const char * fgetstring (FILE * f, char * buf, int size);
template<size_t n> const char * fgetstring (FILE * f, char (& buf)[n]) { return fgetstring (f, buf, n); }
const char * fgetstring (const HANDLE f, char * buf, int size);
template<size_t n> const char * fgetstring (const HANDLE f, char (& buf)[n]) { return fgetstring (f, buf, n); }
const wchar_t * fgetstring (FILE * f, wchar_t * buf, int size);
wstring fgetwstring (FILE * f);
string fgetstring (FILE * f);
const char * fgettoken (FILE * f, char * buf, int size);
template<size_t n> const char * fgettoken (FILE * f, char (& buf)[n]) { return fgettoken (f, buf, n); }
string fgettoken (FILE * f);
const wchar_t * fgettoken (FILE * f, wchar_t * buf, int size);
wstring fgetwtoken (FILE * f);
int fskipNewline (FILE * f, bool skip = true);
int fskipwNewline (FILE * f, bool skip = true);
// ----------------------------------------------------------------------------
// fputstring(): write a 0-terminated string (terminate if error)
// ----------------------------------------------------------------------------
void fputstring (FILE * f, const char *);
void fputstring (FILE * f, const std::string &);
void fputstring (FILE * f, const wchar_t *);
void fputstring (FILE * f, const std::wstring &);
// ----------------------------------------------------------------------------
// fgetTag(): read a 4-byte tag & return as a string
// ----------------------------------------------------------------------------
string fgetTag (FILE * f);
// ----------------------------------------------------------------------------
// fcheckTag(): read a 4-byte tag & verify it; terminate if wrong tag
// ----------------------------------------------------------------------------
void fcheckTag (FILE * f, const char * expectedTag);
void fcheckTag_ascii (FILE * f, const string & expectedTag);
// ----------------------------------------------------------------------------
// fcompareTag(): compare two tags; terminate if wrong tag
// ----------------------------------------------------------------------------
void fcompareTag (const string & readTag, const string & expectedTag);
// ----------------------------------------------------------------------------
// fputTag(): write a 4-byte tag
// ----------------------------------------------------------------------------
void fputTag (FILE * f, const char * tag);
// ----------------------------------------------------------------------------
// fskipstring(): skip a 0-terminated string, such as a pad string
// ----------------------------------------------------------------------------
void fskipstring (FILE * f);
// ----------------------------------------------------------------------------
// fpad(): write a 0-terminated string to pad file to a n-byte boundary
// ----------------------------------------------------------------------------
void fpad (FILE * f, int n);
// ----------------------------------------------------------------------------
// fgetbyte(): read a byte value
// ----------------------------------------------------------------------------
char fgetbyte (FILE * f);
// ----------------------------------------------------------------------------
// fgetshort(): read a short value
// ----------------------------------------------------------------------------
short fgetshort (FILE * f);
short fgetshort_bigendian (FILE * f);
// ----------------------------------------------------------------------------
// fgetint24(): read a 3-byte (24-bit) int value
// ----------------------------------------------------------------------------
int fgetint24 (FILE * f);
// ----------------------------------------------------------------------------
// fgetint(): read an int value
// ----------------------------------------------------------------------------
int fgetint (FILE * f);
int fgetint_bigendian (FILE * f);
int fgetint_ascii (FILE * f);
// ----------------------------------------------------------------------------
// fgetlong(): read an long value
// ----------------------------------------------------------------------------
long fgetlong (FILE * f);
// ----------------------------------------------------------------------------
// fgetfloat(): read a float value
// ----------------------------------------------------------------------------
float fgetfloat (FILE * f);
float fgetfloat_bigendian (FILE * f);
float fgetfloat_ascii (FILE * f);
// ----------------------------------------------------------------------------
// fgetdouble(): read a double value
// ----------------------------------------------------------------------------
double fgetdouble (FILE * f);
// ----------------------------------------------------------------------------
// fputbyte(): write a byte value
// ----------------------------------------------------------------------------
void fputbyte (FILE * f, char val);
// ----------------------------------------------------------------------------
// fputshort(): write a short value
// ----------------------------------------------------------------------------
void fputshort (FILE * f, short val);
// ----------------------------------------------------------------------------
// fputint24(): write a 3-byte (24-bit) int value
// ----------------------------------------------------------------------------
void fputint24 (FILE * f, int v);
// ----------------------------------------------------------------------------
// fputint(): write an int value
// ----------------------------------------------------------------------------
void fputint (FILE * f, int val);
// ----------------------------------------------------------------------------
// fputlong(): write an long value
// ----------------------------------------------------------------------------
void fputlong (FILE * f, long val);
// ----------------------------------------------------------------------------
// fputfloat(): write a float value
// ----------------------------------------------------------------------------
void fputfloat (FILE * f, float val);
// ----------------------------------------------------------------------------
// fputdouble(): write a double value
// ----------------------------------------------------------------------------
void fputdouble (FILE * f, double val);
// template versions of put/get functions for binary files
template <typename T>
void fput(FILE * f, T v)
{
fwriteOrDie (&v, sizeof (v), 1, f);
}
// template versions of put/get functions for binary files
template <typename T>
void fget(FILE * f, T& v)
{
freadOrDie ((void *)&v, sizeof (v), 1, f);
}
// GetFormatString - get the format string for a particular type
template <typename T>
const wchar_t* GetFormatString(T /*t*/)
{
// if this _ASSERT goes off it means that you are using a type that doesn't have
// a read and/or write routine.
// If the type is a user defined class, you need to create some global functions that handles file in/out.
// for example:
//File& operator>>(File& stream, MyClass& test);
//File& operator<<(File& stream, MyClass& test);
//
// in your class you will probably want to add these functions as friends so you can access any private members
// friend File& operator>>(File& stream, MyClass& test);
// friend File& operator<<(File& stream, MyClass& test);
//
// if you are using wchar_t* or char* types, these use other methods because they require buffers to be passed
// either use std::string and std::wstring, or use the WriteString() and ReadString() methods
assert(false); // need a specialization
return NULL;
}
// GetFormatString - specalizations to get the format string for a particular type
template <> const wchar_t* GetFormatString(char);
template <> const wchar_t* GetFormatString(wchar_t);
template <> const wchar_t* GetFormatString(short);
template <> const wchar_t* GetFormatString(int);
template <> const wchar_t* GetFormatString(long);
template <> const wchar_t* GetFormatString(unsigned short);
template <> const wchar_t* GetFormatString(unsigned int);
template <> const wchar_t* GetFormatString(unsigned long);
template <> const wchar_t* GetFormatString(float);
template <> const wchar_t* GetFormatString(double);
template <> const wchar_t* GetFormatString(size_t);
template <> const wchar_t* GetFormatString(long long);
template <> const wchar_t* GetFormatString(const char*);
template <> const wchar_t* GetFormatString(const wchar_t*);
// GetScanFormatString - get the format string for a particular type
template <typename T>
const wchar_t* GetScanFormatString(T t)
{
assert(false); // need a specialization
return NULL;
}
// GetScanFormatString - specalizations to get the format string for a particular type
template <> const wchar_t* GetScanFormatString(char);
template <> const wchar_t* GetScanFormatString(wchar_t);
template <> const wchar_t* GetScanFormatString(short);
template <> const wchar_t* GetScanFormatString(int);
template <> const wchar_t* GetScanFormatString(long);
template <> const wchar_t* GetScanFormatString(unsigned short);
template <> const wchar_t* GetScanFormatString(unsigned int);
template <> const wchar_t* GetScanFormatString(unsigned long);
template <> const wchar_t* GetScanFormatString(float);
template <> const wchar_t* GetScanFormatString(double);
template <> const wchar_t* GetScanFormatString(size_t);
template <> const wchar_t* GetScanFormatString(long long);
// ----------------------------------------------------------------------------
// fgetText(): get a value from a text file
// ----------------------------------------------------------------------------
template <typename T>
void fgetText(FILE * f, T& v)
{
int rc = ftrygetText(f, v);
if (rc == 0)
throw std::runtime_error("error reading value from file (invalid format)");
else if (rc == EOF)
throw std::runtime_error(std::string("error reading from file: ") + strerror(errno));
assert(rc == 1);
}
// version to try and get a string, and not throw exceptions if contents don't match
template <typename T>
int ftrygetText(FILE * f, T& v)
{
const wchar_t* formatString = GetScanFormatString<T>(v);
int rc = fwscanf (f, formatString, &v);
assert(rc == 1 || rc == 0);
return rc;
}
template <> int ftrygetText<bool>(FILE * f, bool& v);
// ----------------------------------------------------------------------------
// fgetText() specializations for fwscanf_s differences: get a value from a text file
// ----------------------------------------------------------------------------
void fgetText(FILE * f, char& v);
void fgetText(FILE * f, wchar_t& v);
// ----------------------------------------------------------------------------
// fputText(): write a value out as text
// ----------------------------------------------------------------------------
template <typename T>
void fputText(FILE * f, T v)
{
const wchar_t* formatString = GetFormatString(v);
int rc = fwprintf(f, formatString, v);
if (rc == 0)
throw std::runtime_error("error writing value to file, no values written");
else if (rc < 0)
throw std::runtime_error(std::string("error writing to file: ") + strerror(errno));
}
// ----------------------------------------------------------------------------
// fputText(): write a bool out as character
// ----------------------------------------------------------------------------
template <> void fputText<bool>(FILE * f, bool v);
// ----------------------------------------------------------------------------
// fputfile(): write a binary block or a string as a file
// ----------------------------------------------------------------------------
void fputfile (const wstring & pathname, const std::vector<char> & buffer);
void fputfile (const wstring & pathname, const std::wstring & string);
void fputfile (const wstring & pathname, const std::string & string);
// ----------------------------------------------------------------------------
// fgetfile(): load a file as a binary block
// ----------------------------------------------------------------------------
void fgetfile (const wstring & pathname, std::vector<char> & buffer);
void fgetfile (FILE * f, std::vector<char> & buffer);
namespace msra { namespace files {
void fgetfilelines (const std::wstring & pathname, vector<char> & readbuffer, std::vector<std::string> & lines);
static inline std::vector<std::string> fgetfilelines (const std::wstring & pathname) { vector<char> buffer; std::vector<std::string> lines; fgetfilelines (pathname, buffer, lines); return lines; }
vector<char*> fgetfilelines (const wstring & pathname, vector<char> & readbuffer);
};};
// ----------------------------------------------------------------------------
// expand_wildcards() -- expand a path with wildcards (also intermediate ones)
// ----------------------------------------------------------------------------
void expand_wildcards (const wstring & path, vector<wstring> & paths);
// ----------------------------------------------------------------------------
// make_intermediate_dirs() -- make all intermediate dirs on a path
// ----------------------------------------------------------------------------
namespace msra { namespace files {
void make_intermediate_dirs (const wstring & filepath);
};};
// ----------------------------------------------------------------------------
// fuptodate() -- test whether an output file is at least as new as an input file
// ----------------------------------------------------------------------------
namespace msra { namespace files {
bool fuptodate (const wstring & target, const wstring & input, bool inputrequired = true);
};};
#if 0
// ----------------------------------------------------------------------------
// simple support for WAV file I/O
// ----------------------------------------------------------------------------
// define the header if we haven't seen it yet
#ifndef _WAVEFORMATEX_
#define _WAVEFORMATEX_
/*
* extended waveform format structure used for all non-PCM formats. this
* structure is common to all non-PCM formats.
*/
typedef unsigned short WORD; // in case not defined yet (i.e. linux)
typedef struct tWAVEFORMATEX
{
WORD wFormatTag; /* format type */
WORD nChannels; /* number of channels (i.e. mono, stereo...) */
DWORD nSamplesPerSec; /* sample rate */
DWORD nAvgBytesPerSec; /* for buffer estimation */
WORD nBlockAlign; /* block size of data */
WORD wBitsPerSample; /* number of bits per sample of mono data */
WORD cbSize; /* the count in bytes of the size of */
/* extra information (after cbSize) */
} WAVEFORMATEX, *PWAVEFORMATEX;
#endif /* _WAVEFORMATEX_ */
typedef struct wavehder{
char riffchar[4];
unsigned int RiffLength;
char wavechar[8];
unsigned int FmtLength;
signed short wFormatTag;
signed short nChannels;
unsigned int nSamplesPerSec;
unsigned int nAvgBytesPerSec;
signed short nBlockAlign;
signed short wBitsPerSample;
char datachar[4];
unsigned int DataLength;
private:
void prepareRest (int SampleCount);
public:
void prepare (unsigned int Fs, int Bits, int Channels, int SampleCount);
void prepare (const WAVEFORMATEX & wfx, int SampleCount);
unsigned int read (FILE * f, signed short & wRealFormatTag, int & bytesPerSample);
void write (FILE * f);
static void update (FILE * f);
} WAVEHEADER;
// ----------------------------------------------------------------------------
// fgetwfx(), fputwfx(): I/O of wave file headers only
// ----------------------------------------------------------------------------
unsigned int fgetwfx (FILE *f, WAVEFORMATEX & wfx);
void fputwfx (FILE *f, const WAVEFORMATEX & wfx, unsigned int numSamples);
// ----------------------------------------------------------------------------
// fgetraw(): read data of .wav file, and separate data of multiple channels.
// For example, data[i][j]: i is channel index, 0 means the first
// channel. j is sample index.
// ----------------------------------------------------------------------------
void fgetraw (FILE *f,std::vector< std::vector<short> > & data,const WAVEHEADER & wavhd);
#endif
// ----------------------------------------------------------------------------
// temp functions -- clean these up
// ----------------------------------------------------------------------------
// split a pathname into directory and filename
static inline void splitpath (const wstring & path, wstring & dir, wstring & file)
{
size_t pos = path.find_last_of (L"\\:/"); // DOS drives, UNIX, Windows
if (pos == path.npos) // no directory found
{
dir.clear();
file = path;
}
else
{
dir = path.substr (0, pos);
file = path.substr (pos +1);
}
}
// test if a pathname is a relative path
// A relative path is one that can be appended to a directory.
// Drive-relative paths, such as D:file, are considered non-relative.
static inline bool relpath (const wchar_t * path)
{ // this is a wild collection of pathname conventions in Windows
if (path[0] == '/' || path[0] == '\\') // e.g. \WINDOWS
return false;
if (path[0] && path[1] == ':') // drive syntax
return false;
// ... TODO: handle long NT paths
return true; // all others
}
template<class CHAR>
static inline bool relpath (const std::basic_string<CHAR> & s) { return relpath (s.c_str()); }
// trim from start
static inline std::string &ltrim(std::string &s) {
s.erase(s.begin(), std::find_if(s.begin(), s.end(), std::not1(std::ptr_fun<int, int>(std::isspace))));
return s;
}
// trim from end
static inline std::string &rtrim(std::string &s) {
s.erase(std::find_if(s.rbegin(), s.rend(), std::not1(std::ptr_fun<int, int>(std::isspace))).base(), s.end());
return s;
}
// trim from both ends
static inline std::string &trim(std::string &s) {
return ltrim(rtrim(s));
}
vector<string> sep_string(const string & str, const string & sep);
#endif // _FILEUTIL_

Просмотреть файл

@ -1,448 +0,0 @@
// TODO: this is a dup; use the one in Include/ instead
//
// <copyright file="fileutil.h" company="Microsoft">
// Copyright (c) Microsoft Corporation. All rights reserved.
// </copyright>
//
#pragma once
#ifndef _FILEUTIL_
#define _FILEUTIL_
#include "basetypes.h"
#include <stdio.h>
#ifdef __WINDOWS__
#include <windows.h> // for mmreg.h and FILETIME
#include <mmreg.h>
#endif
#include <stdint.h>
using namespace std;
#define SAFE_CLOSE(f) (((f) == NULL) || (fcloseOrDie ((f)), (f) = NULL))
// ----------------------------------------------------------------------------
// fopenOrDie(): like fopen() but terminate with err msg in case of error.
// A pathname of "-" returns stdout or stdin, depending on mode, and it will
// change the binary mode if 'b' or 't' are given. If you use this, make sure
// not to fclose() such a handle.
// ----------------------------------------------------------------------------
FILE * fopenOrDie (const STRING & pathname, const char * mode);
FILE * fopenOrDie (const WSTRING & pathname, const wchar_t * mode);
#ifndef __unix__ // don't need binary/text distinction on unix
// ----------------------------------------------------------------------------
// fsetmode(): set mode to binary or text
// ----------------------------------------------------------------------------
void fsetmode (FILE * f, char type);
#endif
// ----------------------------------------------------------------------------
// freadOrDie(): like fread() but terminate with err msg in case of error
// ----------------------------------------------------------------------------
void freadOrDie (void * ptr, size_t size, size_t count, FILE * f);
void freadOrDie (void * ptr, size_t size, size_t count, const HANDLE f);
template<class _T>
void freadOrDie (_T & data, int num, FILE * f) // template for vector<>
{ data.resize (num); if (data.size() > 0) freadOrDie (&data[0], sizeof (data[0]), data.size(), f); }
template<class _T>
void freadOrDie (_T & data, size_t num, FILE * f) // template for vector<>
{ data.resize (num); if (data.size() > 0) freadOrDie (&data[0], sizeof (data[0]), data.size(), f); }
template<class _T>
void freadOrDie (_T & data, int num, const HANDLE f) // template for vector<>
{ data.resize (num); if (data.size() > 0) freadOrDie (&data[0], sizeof (data[0]), data.size(), f); }
template<class _T>
void freadOrDie (_T & data, size_t num, const HANDLE f) // template for vector<>
{ data.resize (num); if (data.size() > 0) freadOrDie (&data[0], sizeof (data[0]), data.size(), f); }
// ----------------------------------------------------------------------------
// fwriteOrDie(): like fwrite() but terminate with err msg in case of error
// ----------------------------------------------------------------------------
void fwriteOrDie (const void * ptr, size_t size, size_t count, FILE * f);
void fwriteOrDie (const void * ptr, size_t size, size_t count, const HANDLE f);
template<class _T>
void fwriteOrDie (const _T & data, FILE * f) // template for vector<>
{ if (data.size() > 0) fwriteOrDie (&data[0], sizeof (data[0]), data.size(), f); }
template<class _T>
void fwriteOrDie (const _T & data, const HANDLE f) // template for vector<>
{ if (data.size() > 0) fwriteOrDie (&data[0], sizeof (data[0]), data.size(), f); }
// ----------------------------------------------------------------------------
// fprintfOrDie(): like fprintf() but terminate with err msg in case of error
// ----------------------------------------------------------------------------
void fprintfOrDie (FILE * f, const char *format, ...);
// ----------------------------------------------------------------------------
// fcloseOrDie(): like fclose() but terminate with err msg in case of error
// not yet implemented, but we should
// ----------------------------------------------------------------------------
#define fcloseOrDie fclose
// ----------------------------------------------------------------------------
// fflushOrDie(): like fflush() but terminate with err msg in case of error
// ----------------------------------------------------------------------------
void fflushOrDie (FILE * f);
// ----------------------------------------------------------------------------
// filesize(): determine size of the file in bytes
// ----------------------------------------------------------------------------
size_t filesize (const wchar_t * pathname);
size_t filesize (FILE * f);
int64_t filesize64 (const wchar_t * pathname);
// ----------------------------------------------------------------------------
// fseekOrDie(),ftellOrDie(), fget/setpos(): seek functions with error handling
// ----------------------------------------------------------------------------
// 32-bit offsets only
long fseekOrDie (FILE * f, long offset, int mode = SEEK_SET);
#define ftellOrDie ftell
uint64_t fgetpos (FILE * f);
void fsetpos (FILE * f, uint64_t pos);
// ----------------------------------------------------------------------------
// unlinkOrDie(): unlink() with error handling
// ----------------------------------------------------------------------------
void unlinkOrDie (const std::string & pathname);
void unlinkOrDie (const std::wstring & pathname);
// ----------------------------------------------------------------------------
// renameOrDie(): rename() with error handling
// ----------------------------------------------------------------------------
void renameOrDie (const std::string & from, const std::string & to);
void renameOrDie (const std::wstring & from, const std::wstring & to);
// ----------------------------------------------------------------------------
// fexists(): test if a file exists
// ----------------------------------------------------------------------------
bool fexists (const char * pathname);
bool fexists (const wchar_t * pathname);
inline bool fexists (const std::string & pathname) { return fexists (pathname.c_str()); }
inline bool fexists (const std::wstring & pathname) { return fexists (pathname.c_str()); }
// ----------------------------------------------------------------------------
// funicode(): test if a file uses unicode
// ----------------------------------------------------------------------------
bool funicode (FILE * f);
// ----------------------------------------------------------------------------
// fskipspace(): skip space characters
// ----------------------------------------------------------------------------
void fskipspace (FILE * F);
// ----------------------------------------------------------------------------
// fgetline(): like fgets() but terminate with err msg in case of error;
// removes the newline character at the end (like gets()), returned buffer is
// always 0-terminated; has second version that returns an STL string instead
// fgetstring(): read a 0-terminated string (terminate if error)
// fgetword(): read a space-terminated token (terminate if error)
// fskipNewLine(): skip all white space until end of line incl. the newline
// ----------------------------------------------------------------------------
template<class CHAR> CHAR * fgetline (FILE * f, CHAR * buf, int size);
template<class CHAR, size_t n> CHAR * fgetline (FILE * f, CHAR (& buf)[n]) { return fgetline (f, buf, n); }
STRING fgetline (FILE * f);
WSTRING fgetlinew (FILE * f);
void fgetline (FILE * f, std::string & s, ARRAY<char> & buf);
void fgetline (FILE * f, std::wstring & s, ARRAY<char> & buf);
void fgetline (FILE * f, ARRAY<char> & buf);
void fgetline (FILE * f, ARRAY<wchar_t> & buf);
const char * fgetstring (FILE * f, char * buf, int size);
template<size_t n> const char * fgetstring (FILE * f, char (& buf)[n]) { return fgetstring (f, buf, n); }
const char * fgetstring (const HANDLE f, char * buf, int size);
template<size_t n> const char * fgetstring (const HANDLE f, char (& buf)[n]) { return fgetstring (f, buf, n); }
wstring fgetwstring (FILE * f);
const char * fgettoken (FILE * f, char * buf, int size);
template<size_t n> const char * fgettoken (FILE * f, char (& buf)[n]) { return fgettoken (f, buf, n); }
STRING fgettoken (FILE * f);
void fskipNewline (FILE * f);
// ----------------------------------------------------------------------------
// fputstring(): write a 0-terminated string (terminate if error)
// ----------------------------------------------------------------------------
void fputstring (FILE * f, const char *);
void fputstring (const HANDLE f, const char * str);
void fputstring (FILE * f, const std::string &);
void fputstring (FILE * f, const wchar_t *);
void fputstring (FILE * f, const std::wstring &);
// ----------------------------------------------------------------------------
// fgetTag(): read a 4-byte tag & return as a string
// ----------------------------------------------------------------------------
STRING fgetTag (FILE * f);
// ----------------------------------------------------------------------------
// fcheckTag(): read a 4-byte tag & verify it; terminate if wrong tag
// ----------------------------------------------------------------------------
void fcheckTag (FILE * f, const char * expectedTag);
void fcheckTag (const HANDLE f, const char * expectedTag);
void fcheckTag_ascii (FILE * f, const STRING & expectedTag);
// ----------------------------------------------------------------------------
// fcompareTag(): compare two tags; terminate if wrong tag
// ----------------------------------------------------------------------------
void fcompareTag (const STRING & readTag, const STRING & expectedTag);
// ----------------------------------------------------------------------------
// fputTag(): write a 4-byte tag
// ----------------------------------------------------------------------------
void fputTag (FILE * f, const char * tag);
void fputTag(const HANDLE f, const char * tag);
// ----------------------------------------------------------------------------
// fskipstring(): skip a 0-terminated string, such as a pad string
// ----------------------------------------------------------------------------
void fskipstring (FILE * f);
// ----------------------------------------------------------------------------
// fpad(): write a 0-terminated string to pad file to a n-byte boundary
// ----------------------------------------------------------------------------
void fpad (FILE * f, int n);
// ----------------------------------------------------------------------------
// fgetbyte(): read a byte value
// ----------------------------------------------------------------------------
char fgetbyte (FILE * f);
// ----------------------------------------------------------------------------
// fgetshort(): read a short value
// ----------------------------------------------------------------------------
short fgetshort (FILE * f);
short fgetshort_bigendian (FILE * f);
// ----------------------------------------------------------------------------
// fgetint24(): read a 3-byte (24-bit) int value
// ----------------------------------------------------------------------------
int fgetint24 (FILE * f);
// ----------------------------------------------------------------------------
// fgetint(): read an int value
// ----------------------------------------------------------------------------
int fgetint (FILE * f);
int fgetint (const HANDLE f);
int fgetint_bigendian (FILE * f);
int fgetint_ascii (FILE * f);
// ----------------------------------------------------------------------------
// fgetfloat(): read a float value
// ----------------------------------------------------------------------------
float fgetfloat (FILE * f);
float fgetfloat_bigendian (FILE * f);
float fgetfloat_ascii (FILE * f);
// ----------------------------------------------------------------------------
// fgetdouble(): read a double value
// ----------------------------------------------------------------------------
double fgetdouble (FILE * f);
// ----------------------------------------------------------------------------
// fgetwav(): read an entire .wav file
// ----------------------------------------------------------------------------
void fgetwav (FILE * f, ARRAY<short> & wav, int & sampleRate);
void fgetwav (const wstring & fn, ARRAY<short> & wav, int & sampleRate);
// ----------------------------------------------------------------------------
// fputwav(): save data into a .wav file
// ----------------------------------------------------------------------------
void fputwav (FILE * f, const vector<short> & wav, int sampleRate, int nChannels = 1);
void fputwav (const wstring & fn, const vector<short> & wav, int sampleRate, int nChannels = 1);
// ----------------------------------------------------------------------------
// fputbyte(): write a byte value
// ----------------------------------------------------------------------------
void fputbyte (FILE * f, char val);
// ----------------------------------------------------------------------------
// fputshort(): write a short value
// ----------------------------------------------------------------------------
void fputshort (FILE * f, short val);
// ----------------------------------------------------------------------------
// fputint24(): write a 3-byte (24-bit) int value
// ----------------------------------------------------------------------------
void fputint24 (FILE * f, int v);
// ----------------------------------------------------------------------------
// fputint(): write an int value
// ----------------------------------------------------------------------------
void fputint (FILE * f, int val);
void fputint (const HANDLE f, int v);
// ----------------------------------------------------------------------------
// fputfloat(): write a float value
// ----------------------------------------------------------------------------
void fputfloat (FILE * f, float val);
// ----------------------------------------------------------------------------
// fputdouble(): write a double value
// ----------------------------------------------------------------------------
void fputdouble (FILE * f, double val);
// ----------------------------------------------------------------------------
// fputfile(): write a binary block or a string as a file
// ----------------------------------------------------------------------------
void fputfile (const WSTRING & pathname, const ARRAY<char> & buffer);
void fputfile (const WSTRING & pathname, const std::wstring & string);
void fputfile (const WSTRING & pathname, const std::string & string);
// ----------------------------------------------------------------------------
// fgetfile(): load a file as a binary block
// ----------------------------------------------------------------------------
void fgetfile (const WSTRING & pathname, ARRAY<char> & buffer);
void fgetfile (FILE * f, ARRAY<char> & buffer);
namespace msra { namespace files {
void fgetfilelines (const std::wstring & pathname, vector<char> & readbuffer, std::vector<std::string> & lines);
static inline std::vector<std::string> fgetfilelines (const std::wstring & pathname) { vector<char> buffer; std::vector<std::string> lines; fgetfilelines (pathname, buffer, lines); return lines; }
vector<char*> fgetfilelines (const wstring & pathname, vector<char> & readbuffer);
};};
// ----------------------------------------------------------------------------
// getfiletime(), setfiletime(): access modification time
// ----------------------------------------------------------------------------
bool getfiletime (const std::wstring & path, FILETIME & time);
void setfiletime (const std::wstring & path, const FILETIME & time);
// ----------------------------------------------------------------------------
// expand_wildcards() -- expand a path with wildcards (also intermediate ones)
// ----------------------------------------------------------------------------
void expand_wildcards (const wstring & path, vector<wstring> & paths);
// ----------------------------------------------------------------------------
// make_intermediate_dirs() -- make all intermediate dirs on a path
// ----------------------------------------------------------------------------
namespace msra { namespace files {
void make_intermediate_dirs (const wstring & filepath);
};};
// ----------------------------------------------------------------------------
// fuptodate() -- test whether an output file is at least as new as an input file
// ----------------------------------------------------------------------------
namespace msra { namespace files {
bool fuptodate (const wstring & target, const wstring & input, bool inputrequired = true);
};};
// ----------------------------------------------------------------------------
// simple support for WAV file I/O
// ----------------------------------------------------------------------------
typedef struct wavehder{
char riffchar[4];
unsigned int RiffLength;
char wavechar[8];
unsigned int FmtLength;
signed short wFormatTag;
signed short nChannels;
unsigned int nSamplesPerSec;
unsigned int nAvgBytesPerSec;
signed short nBlockAlign;
signed short wBitsPerSample;
char datachar[4];
unsigned int DataLength;
private:
void prepareRest (int SampleCount);
public:
void prepare (unsigned int Fs, int Bits, int Channels, int SampleCount);
void prepare (const WAVEFORMATEX & wfx, int SampleCount);
unsigned int read (FILE * f, signed short & wRealFormatTag, int & bytesPerSample);
void write (FILE * f);
static void update (FILE * f);
} WAVEHEADER;
// ----------------------------------------------------------------------------
// fgetwfx(), fputwfx(): I/O of wave file headers only
// ----------------------------------------------------------------------------
unsigned int fgetwfx (FILE *f, WAVEFORMATEX & wfx);
void fputwfx (FILE *f, const WAVEFORMATEX & wfx, unsigned int numSamples);
// ----------------------------------------------------------------------------
// fgetraw(): read data of .wav file, and separate data of multiple channels.
// For example, data[i][j]: i is channel index, 0 means the first
// channel. j is sample index.
// ----------------------------------------------------------------------------
void fgetraw (FILE *f,ARRAY< ARRAY<short> > & data,const WAVEHEADER & wavhd);
// ----------------------------------------------------------------------------
// temp functions -- clean these up
// ----------------------------------------------------------------------------
// split a pathname into directory and filename
static inline void splitpath (const wstring & path, wstring & dir, wstring & file)
{
size_t pos = path.find_last_of (L"\\:/"); // DOS drives, UNIX, Windows
if (pos == path.npos) // no directory found
{
dir.clear();
file = path;
}
else
{
dir = path.substr (0, pos);
file = path.substr (pos +1);
}
}
// test if a pathname is a relative path
// A relative path is one that can be appended to a directory.
// Drive-relative paths, such as D:file, are considered non-relative.
static inline bool relpath (const wchar_t * path)
{ // this is a wild collection of pathname conventions in Windows
if (path[0] == '/' || path[0] == '\\') // e.g. \WINDOWS
return false;
if (path[0] && path[1] == ':') // drive syntax
return false;
// ... TODO: handle long NT paths
return true; // all others
}
template<class CHAR>
static inline bool relpath (const std::basic_string<CHAR> & s) { return relpath (s.c_str()); }
#endif // _FILEUTIL_

Просмотреть файл

@ -1,951 +0,0 @@
//
// <copyright file="htkfeatio.h" company="Microsoft">
// Copyright (c) Microsoft Corporation. All rights reserved.
// </copyright>
//
// htkfeatio.h -- helper for I/O of HTK feature files
#pragma once
#include "basetypes.h"
#include "fileutil.h"
#include "simple_checked_arrays.h"
#include <string>
#include <regex>
#include <set>
#include <hash_map>
#include <stdint.h>
#include <limits.h>
#include <wchar.h>
namespace msra { namespace asr {
// ===========================================================================
// htkfeatio -- common base class for reading and writing HTK feature files
// ===========================================================================
class htkfeatio
{
protected:
auto_file_ptr f;
wstring physicalpath; // path of this file
bool needbyteswapping; // need to swap the bytes?
string featkind; // HTK feature-kind string
size_t featdim; // feature dimension
unsigned int featperiod; // sampling period
// note that by default we assume byte swapping (seems to be HTK default)
htkfeatio() : needbyteswapping (true), featdim (0), featperiod (0) {}
// set the feature kind variables --if already set then validate that they are the same
// Path is only for error message.
void setkind (string kind, size_t dim, unsigned int period, const wstring & path)
{
if (featkind.empty()) // not set yet: just memorize them
{
assert (featdim == 0 && featperiod == 0);
featkind = kind;
featdim = dim;
featperiod = period;
}
else // set already: check if consistent
{
if (featkind != kind || featdim != dim || featperiod != period)
throw std::runtime_error (msra::strfun::strprintf ("setkind: inconsistent feature kind for file '%S'", path.c_str()));
}
}
static short swapshort (short v) throw()
{
const unsigned char * b = (const unsigned char *) &v;
return (short) ((b[0] << 8) + b[1]);
}
static int swapint (int v) throw()
{
const unsigned char * b = (const unsigned char *) &v;
return (int) (((((b[0] << 8) + b[1]) << 8) + b[2]) << 8) + b[3];
}
struct fileheader
{
int nsamples;
int sampperiod;
short sampsize;
short sampkind;
void read (FILE * f)
{
nsamples = fgetint (f);
sampperiod = fgetint (f);
sampsize = fgetshort (f);
sampkind = fgetshort (f);
}
// read header of idx feature cach
void idxRead (FILE * f)
{
int magic = swapint(fgetint (f));
if (magic != 2051)
throw std::runtime_error ("reading idx feature cache header: invalid magic");
nsamples = swapint(fgetint(f));
sampperiod = 0;
sampkind = (short)9; //user type
int nRows = swapint(fgetint(f));
int nCols = swapint(fgetint(f));
sampsize = (short) (nRows * nCols); // features are stored as bytes;
}
void write (FILE * f)
{
fputint (f, nsamples);
fputint (f, sampperiod);
fputshort (f, sampsize);
fputshort (f, sampkind);
}
void byteswap()
{
nsamples = swapint (nsamples);
sampperiod = swapint (sampperiod);
sampsize = swapshort (sampsize);
sampkind = swapshort (sampkind);
}
};
static const int BASEMASK = 077;
static const int PLP = 11;
static const int MFCC = 6;
static const int FBANK = 7;
static const int USER = 9;
static const int FESTREAM = 12;
static const int HASENERGY = 0100; // _E log energy included
static const int HASNULLE = 0200; // _N absolute energy suppressed
static const int HASDELTA = 0400; // _D delta coef appended
static const int HASACCS = 01000; // _A acceleration coefs appended
static const int HASCOMPX = 02000; // _C is compressed
static const int HASZEROM = 04000; // _Z zero meaned
static const int HASCRCC = 010000; // _K has CRC check
static const int HASZEROC = 020000; // _0 0'th Cepstra included
static const int HASVQ = 040000; // _V has VQ index attached
static const int HASTHIRD = 0100000; // _T has Delta-Delta-Delta index attached
};
// ===========================================================================
// htkfeatwriter -- write HTK feature file
// This is designed to write a single file only (no archive mode support).
// ===========================================================================
class htkfeatwriter : protected htkfeatio
{
size_t curframe;
vector<float> tmp;
public:
short parsekind (const string & str)
{
vector<string> params = msra::strfun::split (str, ";");
if (params.empty())
throw std::runtime_error ("parsekind: invalid param kind string");
vector<string> parts = msra::strfun::split (params[0], "_");
// map base kind
short sampkind;
string basekind = parts[0];
if (basekind == "PLP") sampkind = PLP;
else if (basekind == "MFCC") sampkind = MFCC;
else if (basekind == "FBANK") sampkind = FBANK;
else if (basekind == "USER") sampkind = USER;
else throw std::runtime_error ("parsekind: unsupported param base kind");
// map qualifiers
for (size_t i = 1; i < parts.size(); i++)
{
string opt = parts[i];
if (opt.length() != 1)
throw std::runtime_error ("parsekind: invalid param kind string");
switch (opt[0])
{
case 'E': sampkind |= HASENERGY; break;
case 'D': sampkind |= HASDELTA; break;
case 'N': sampkind |= HASNULLE; break;
case 'A': sampkind |= HASACCS; break;
case 'T': sampkind |= HASTHIRD; break;
case 'Z': sampkind |= HASZEROM; break;
case '0': sampkind |= HASZEROC; break;
default: throw std::runtime_error ("parsekind: invalid qualifier in param kind string");
}
}
return sampkind;
}
public:
// open the file for writing
htkfeatwriter (wstring path, string kind, size_t dim, unsigned int period)
{
setkind (kind, dim, period, path);
// write header
fileheader H;
H.nsamples = 0; // unknown for now, updated in close()
H.sampperiod = period;
const int bytesPerValue = sizeof (float); // we do not support compression for now
H.sampsize = (short) featdim * bytesPerValue;
H.sampkind = parsekind (kind);
if (needbyteswapping)
H.byteswap();
f = fopenOrDie (path, L"wbS");
H.write (f);
curframe = 0;
}
// write a frame
void write (const vector<float> & v)
{
if (v.size() != featdim)
throw std::logic_error ("htkfeatwriter: inconsistent feature dimension");
if (needbyteswapping)
{
tmp.resize (v.size());
foreach_index (k, v) tmp[k] = v[k];
msra::util::byteswap (tmp);
fwriteOrDie (tmp, f);
}
else
fwriteOrDie (v, f);
curframe++;
}
// finish
// This updates the header.
// BUGBUG: need to implement safe-save semantics! Otherwise won't work reliably with -make mode.
// ... e.g. set DeleteOnClose temporarily, and clear at the end?
void close (size_t numframes)
{
if (curframe != numframes)
throw std::logic_error ("htkfeatwriter: inconsistent number of frames passed to close()");
fflushOrDie (f);
// now implant the length field; it's at offset 0
int nSamplesFile = (int) numframes;
if (needbyteswapping)
nSamplesFile = swapint (nSamplesFile);
fseekOrDie (f, 0);
fputint (f, nSamplesFile);
fflushOrDie (f);
f = NULL; // this triggers an fclose() on auto_file_ptr
}
// read an entire utterance into a matrix
// Matrix type needs to have operator(i,j) and resize(n,m).
// We write to a tmp file first to ensure we don't leave broken files that would confuse make mode.
template<class MATRIX> static void write (const wstring & path, const string & kindstr, unsigned int period, const MATRIX & feat)
{
wstring tmppath = path + L"$$"; // tmp path for make-mode compliant
unlinkOrDie (path); // delete if old file is already there
// write it out
size_t featdim = feat.rows();
size_t numframes = feat.cols();
vector<float> v (featdim);
htkfeatwriter W (tmppath, kindstr, feat.rows(), period);
#ifdef SAMPLING_EXPERIMENT
for (size_t i = 0; i < numframes; i++)
{
foreach_index (k, v)
{
float val = feat(k,i) - logf((float) SAMPLING_EXPERIMENT);
if (i % SAMPLING_EXPERIMENT == 0)
v[k] = val;
else
v[k] += (float) (log (1 + exp (val - v[k]))); // log add
}
if (i % SAMPLING_EXPERIMENT == SAMPLING_EXPERIMENT -1)
W.write (v);
}
#else
for (size_t i = 0; i < numframes; i++)
{
foreach_index (k, v)
v[k] = feat(k,i);
W.write (v);
}
#endif
#ifdef SAMPLING_EXPERIMENT
W.close (numframes / SAMPLING_EXPERIMENT);
#else
W.close (numframes);
#endif
// rename to final destination
// (This would only fail in strange circumstances such as accidental multiple processes writing to the same file.)
// renameOrDie (tmppath, path);
}
};
// ===========================================================================
// htkfeatreader -- read HTK feature file, with archive support
//
// To support archives, one instance of this can (and is supposed to) be used
// repeatedly. All feat files read on the same instance are validated to have
// the same feature kind.
//
// For archives, this caches the last used file handle, in expectation that most reads
// are sequential anyway. In conjunction with a big buffer, this makes a huge difference.
// ===========================================================================
class htkfeatreader : protected htkfeatio
{
// information on current file
// File handle and feature type information is stored in the underlying htkfeatio object.
size_t physicalframes; // total number of frames in physical file
//TODO make this nicer
bool isidxformat; // support reading of features in idxformat as well (it's a hack, but different format's are not supported yet)
uint64_t physicaldatastart; // byte offset of first data byte
size_t vecbytesize; // size of one vector in bytes
bool addEnergy; // add in energy as data is read (will all have zero values)
bool compressed; // is compressed to 16-bit values
bool hascrcc; // need to skip crcc
vector<float> a, b; // for decompression
vector<short> tmp; // for decompression
vector<unsigned char> tmpByteVector; // for decompression of idx files
size_t curframe; // current # samples read so far
size_t numframes; // number of samples for current logical file
size_t energyElements; // how many energy elements to add if addEnergy is true
public:
// parser for complex a=b[s,e] syntax
struct parsedpath
{
protected:
friend class htkfeatreader;
bool isarchive; // true if archive (range specified)
bool isidxformat; // support reading of features in idxformat as well (it's a hack, but different format's are not supported yet)
wstring xpath; // original full path specification as passed to constructor (for error messages)
wstring logicalpath; // virtual path that this file should be understood to belong to
wstring archivepath; // physical path of archive file
size_t s, e; // first and last frame inside the archive file; (0, INT_MAX) if not given
void malformed() const { throw std::runtime_error (msra::strfun::strprintf ("parsedpath: malformed path '%S'", xpath.c_str())); }
// consume and return up to 'delim'; remove from 'input' (we try to avoid C++0x here for VS 2008 compat)
wstring consume (wstring & input, const wchar_t * delim)
{
vector<wstring> parts = msra::strfun::split (input, delim); // (not very efficient, but does not matter here)
if (parts.size() == 1) input.clear(); // not found: consume to end
else input = parts[1]; // found: break at delimiter
return parts[0];
}
public:
// constructor parses a=b[s,e] syntax and fills in the file
// Can be used implicitly e.g. by passing a string to open().
parsedpath (wstring xpath) : xpath (xpath)
{
// parse out logical path
logicalpath = consume (xpath, L"=");
isidxformat = false;
if (xpath.empty()) // no '=' detected: pass entire file (it's not an archive)
{
archivepath = logicalpath;
s = 0;
e = INT_MAX;
isarchive = false;
// check for "-ubyte" suffix in path name => it is an idx file
wstring ubyte(L"-ubyte");
size_t pos = archivepath.size() >= ubyte.size() ? archivepath.size() - ubyte.size() : 0;
wstring suffix = archivepath.substr(pos , ubyte.size());
isidxformat = ubyte == suffix;
}
else // a=b[s,e] syntax detected
{
archivepath = consume (xpath, L"[");
if (xpath.empty()) // actually it's only a=b
{
s = 0;
e = INT_MAX;
isarchive = false;
}
else
{
s = msra::strfun::toint (consume (xpath, L","));
if (xpath.empty()) malformed();
e = msra::strfun::toint (consume (xpath, L"]"));
if (!xpath.empty()) malformed();
isarchive = true;
}
}
}
// get the physical path for 'make' test
const wstring & physicallocation() const { return archivepath; }
// casting to wstring yields the logical path
operator const wstring & () const { return logicalpath; }
// get duration in frames
size_t numframes() const
{
if (!isarchive)
throw runtime_error ("parsedpath: this mode requires an input script with start and end frames given");
return e - s + 1;
}
};
private:
// open the physical HTK file
// This is different from the logical (virtual) path name in the case of an archive.
void openphysical (const parsedpath & ppath)
{
wstring physpath = ppath.physicallocation();
//auto_file_ptr f = fopenOrDie (physpath, L"rbS");
auto_file_ptr f (fopenOrDie (physpath, L"rb")); // removed 'S' for now, as we mostly run local anyway, and this will speed up debugging
// read the header (12 bytes for htk feature files)
fileheader H;
isidxformat = ppath.isidxformat;
if (!isidxformat)
H.read (f);
else // read header of idxfile
H.idxRead (f);
// take a guess as to whether we need byte swapping or not
bool needbyteswapping = ((unsigned int) swapint (H.sampperiod) < (unsigned int) H.sampperiod);
if (needbyteswapping)
H.byteswap();
// interpret sampkind
int basekind = H.sampkind & BASEMASK;
string kind;
switch (basekind)
{
case PLP: kind = "PLP"; break;
case MFCC: kind = "MFCC"; break;
case FBANK: kind = "FBANK"; break;
case USER: kind = "USER"; break;
case FESTREAM: kind = "USER"; break; // we return this as USER type (with guid)
default: throw std::runtime_error ("htkfeatreader:unsupported feature kind");
}
// add qualifiers
if (H.sampkind & HASENERGY) kind += "_E";
if (H.sampkind & HASDELTA) kind += "_D";
if (H.sampkind & HASNULLE) kind += "_N";
if (H.sampkind & HASACCS) kind += "_A";
if (H.sampkind & HASTHIRD) kind += "_T";
bool compressed = (H.sampkind & HASCOMPX) != 0;
bool hascrcc = (H.sampkind & HASCRCC) != 0;
if (H.sampkind & HASZEROM) kind += "_Z";
if (H.sampkind & HASZEROC) kind += "_0";
if (H.sampkind & HASVQ) throw std::runtime_error ("htkfeatreader:we do not support VQ");
// skip additional GUID in FESTREAM features
if (H.sampkind == FESTREAM)
{ // ... note: untested
unsigned char guid[16];
freadOrDie (&guid, sizeof (guid), 1, f);
kind += ";guid=";
for (int i = 0; i < sizeof (guid)/sizeof (*guid); i++)
kind += msra::strfun::strprintf ("%02x", guid[i]);
}
// other checks
size_t bytesPerValue = isidxformat ? 1 : (compressed ? sizeof (short) : sizeof (float));
if (H.sampsize % bytesPerValue != 0) throw std::runtime_error ("htkfeatreader:sample size not multiple of dimension");
size_t dim = H.sampsize / bytesPerValue;
// read the values for decompressing
vector<float> a, b;
if (compressed)
{
freadOrDie (a, dim, f);
freadOrDie (b, dim, f);
H.nsamples -= 4; // these are counted as 4 frames--that's the space they use
if (needbyteswapping) { msra::util::byteswap (a); msra::util::byteswap (b); }
}
// done: swap it in
int64_t bytepos = fgetpos (f);
setkind (kind, dim, H.sampperiod, ppath); // this checks consistency
this->physicalpath.swap (physpath);
this->physicaldatastart = bytepos;
this->physicalframes = H.nsamples;
this->f.swap (f); // note: this will get the previous f auto-closed at the end of this function
this->needbyteswapping = needbyteswapping;
this->compressed = compressed;
this->a.swap (a);
this->b.swap (b);
this->vecbytesize = H.sampsize;
this->hascrcc = hascrcc;
}
void close() // force close the open file --use this in case of read failure
{
f = NULL; // assigning a new FILE* to f will close the old FILE* if any
physicalpath.clear();
}
public:
htkfeatreader() {addEnergy = false; energyElements = 0;}
// helper to create a parsed-path object
// const auto path = parse (xpath)
parsedpath parse (const wstring & xpath) { return parsedpath (xpath); }
// read a feature file
// Returns number of frames in that file.
// This understands the more complex syntax a=b[s,e] and optimizes a little
size_t open (const parsedpath & ppath)
{
// do not reopen the file if it is the same; use fsetpos() instead
if (f == NULL || ppath.physicallocation() != physicalpath)
openphysical (ppath);
if (ppath.isarchive) // reading a sub-range from an archive
{
if (ppath.s > ppath.e)
throw std::runtime_error (msra::strfun::strprintf ("open: start frame > end frame in '%S'", ppath.e, physicalframes, ppath.xpath.c_str()));
if (ppath.e >= physicalframes)
throw std::runtime_error (msra::strfun::strprintf ("open: end frame exceeds archive's total number of frames %d in '%S'", physicalframes, ppath.xpath.c_str()));
int64_t dataoffset = physicaldatastart + ppath.s * vecbytesize;
fsetpos (f, dataoffset); // we assume fsetpos(), which is our own, is smart to not flush the read buffer
curframe = 0;
numframes = ppath.e + 1 - ppath.s;
}
else // reading a full file
{
curframe = 0;
numframes = physicalframes;
assert (fgetpos (f) == physicaldatastart);
}
return numframes;
}
// get dimension and type information for a feature file
// This will alter the state of this object in that it opens the file. It is efficient to read it right afterwards
void getinfo (const parsedpath & ppath, string & featkind, size_t & featdim, unsigned int & featperiod)
{
open (ppath);
featkind = this->featkind;
featdim = this->featdim;
featperiod = this->featperiod;
}
// called to add energy as we read
void AddEnergy(size_t energyElements)
{
this->energyElements = energyElements;
this->addEnergy = energyElements != 0;
}
const string & getfeattype() const { return featkind; }
operator bool() const { return curframe < numframes; }
// read a vector from the open file
void read (std::vector<float> & v)
{
if (curframe >= numframes) throw std::runtime_error ("htkfeatreader:attempted to read beyond end");
if (!compressed && !isidxformat) // not compressed--the easy one
{
freadOrDie (v, featdim, f);
if (needbyteswapping) msra::util::byteswap (v);
}
else if (isidxformat)
{
// read into temp vector
freadOrDie (tmpByteVector, featdim, f);
v.resize (featdim);
foreach_index (k, v)
v[k] = (float) tmpByteVector[k];
}
else // need to decompress
{
// read into temp vector
freadOrDie (tmp, featdim, f);
if (needbyteswapping) msra::util::byteswap (tmp);
// 'decompress' it
v.resize (tmp.size());
foreach_index (k, v)
v[k] = (tmp[k] + b[k]) / a[k];
}
curframe++;
}
// read a sequence of vectors from the open file into a range of frames [ts,te)
template<class MATRIX> void read (MATRIX & feat, size_t ts, size_t te)
{
// read vectors from file and push to our target structure
vector<float> v(featdim+energyElements);
for (size_t t = ts; t < te; t++)
{
read (v);
// add the energy elements (all zero) if needed
if (addEnergy)
{
// we add the energy elements at the end of each section of features, (features, delta, delta-delta)
size_t posIncrement = featdim/energyElements;
size_t pos = posIncrement;
for (size_t i=0;i < energyElements;i++,pos+=posIncrement)
{
auto iter = v.begin() + pos + i;
v.insert(iter,0.0f);
}
}
foreach_index (k, v)
feat(k,t) = v[k];
}
}
// read an entire utterance into an already allocated matrix
// Matrix type needs to have operator(i,j)
template<class MATRIX> void read (const parsedpath & ppath, const string & kindstr, const unsigned int period, MATRIX & feat)
{
// open the file and check dimensions
size_t numframes = open (ppath);
if (feat.cols() != numframes || feat.rows() != featdim)
throw std::logic_error ("read: stripe read called with wrong dimensions");
if (kindstr != featkind || period != featperiod)
throw std::logic_error ("read: attempting to mixing different feature kinds");
// read vectors from file and push to our target structure
try { read (feat, 0, numframes); } catch (...) { close(); throw; }
}
// read an entire utterance into a virgen, allocatable matrix
// Matrix type needs to have operator(i,j) and resize(n,m)
template<class MATRIX> void read (const parsedpath & ppath, string & kindstr, unsigned int & period, MATRIX & feat)
{
// get the file
size_t numframes = open (ppath);
feat.resize (featdim+energyElements, numframes); // result matrix--columns are features
// read vectors from file and push to our target structure
try { read (feat, 0, numframes); } catch (...) { close(); throw; }
// return file info
kindstr = featkind;
period = featperiod;
}
};
struct htkmlfentry
{
unsigned int firstframe; // range [firstframe,firstframe+numframes)
unsigned int numframes;
//unsigned short classid; // numeric state id
unsigned int classid; // numeric state id - mseltzer changed from ushort to uint for untied cd phones > 2^16
private:
// verify and save data
void setdata (size_t ts, size_t te, size_t uid)
{
if (te < ts) throw std::runtime_error ("htkmlfentry: end time below start time??");
// save
firstframe = (unsigned int) ts;
numframes = (unsigned int) (te - ts);
classid = (unsigned int) uid;
// check for numeric overflow
if (firstframe != ts || firstframe + numframes != te || classid != uid)
throw std::runtime_error ("htkmlfentry: not enough bits for one of the values");
}
// parse the time range
// There are two formats:
// - original HTK
// - Dong's hacked format: ts te senonename senoneid
// We distinguish
static void parseframerange (const vector<char*> & toks, size_t & ts, size_t & te, const double htkTimeToFrame)
{
const double maxFrameNumber = htkTimeToFrame / 2.0; // if frame number is greater than this we assume it is time instead of frame
double rts = msra::strfun::todouble (toks[0]);
double rte = msra::strfun::todouble (toks[1]);
if (rte > maxFrameNumber) // convert time to frame
{
ts = (size_t) (rts/htkTimeToFrame + 0.5); // get start frame
te = (size_t) (rte/htkTimeToFrame + 0.5); // get end frame
}
else
{
ts = (size_t)(rts);
te = (size_t)(rte);
}
}
public:
// parse format with original HTK state align MLF format and state list
void parsewithstatelist (const vector<char*> & toks, const hash_map<std::string, size_t> & statelisthash, const double htkTimeToFrame)
{
size_t ts, te;
parseframerange (toks, ts, te, htkTimeToFrame);
auto iter = statelisthash.find (toks[2]);
if (iter == statelisthash.end())
throw std::runtime_error (msra::strfun::strprintf ("htkmlfentry: state %s not found in statelist", toks[2]));
const size_t uid = iter->second; // get state index
setdata (ts, te, uid);
}
// ... note: this will be too simplistic for parsing more complex MLF formats. Fix when needed.
// add support so that it can handle conditions where time instead of frame numer is used.
void parse (const vector<char*> & toks, const double htkTimeToFrame)
{
if (toks.size() != 4) throw std::runtime_error ("htkmlfentry: currently we only support 4-column format");
size_t ts, te;
parseframerange (toks, ts, te, htkTimeToFrame);
size_t uid = msra::strfun::toint (toks[3]);
setdata(ts, te, uid);
}
};
template<class ENTRY, class WORDSEQUENCE>
class htkmlfreader : public map<wstring,vector<ENTRY>> // [key][i] the data
{
wstring curpath; // for error messages
hash_map<std::string, size_t> statelistmap; // for state <=> index
map<wstring,WORDSEQUENCE> wordsequences; // [key] word sequences (if we are building word entries as well, for MMI)
void strtok (char * s, const char * delim, vector<char*> & toks)
{
toks.resize (0);
char * context = nullptr;
for (char * p = strtok_s (s, delim, &context); p; p = strtok_s (NULL, delim, &context))
toks.push_back (p);
}
void malformed (string what)
{
throw std::runtime_error (msra::strfun::strprintf ("htkmlfreader: %s in '%S'", what.c_str(), curpath.c_str()));
}
vector<char*> readlines (const wstring & path, vector<char> & buffer)
{
// load it into RAM in one huge chunk
auto_file_ptr f (fopenOrDie (path, L"rb"));
size_t len = filesize (f);
buffer.reserve (len +1);
freadOrDie (buffer, len, f);
buffer.push_back (0); // this makes it a proper C string
// parse into lines
vector<char *> lines;
lines.reserve (len / 20);
strtok (&buffer[0], "\r\n", lines);
return lines;
}
// determine mlf entry lines range
// lines range: [s,e)
size_t getnextmlfstart (vector<char*> & lines, size_t s)
{
// determine lines range
size_t e;
for (e = s ; ; e++)
{
if (e >= lines.size()) malformed ("unexpected end in mid-utterance");
char * ll = lines[e];
if (ll[0] == '.' && ll[1] == 0) // end delimiter: a single dot on a line
break;
}
return (e + 1);
// lines range: [s,e)
}
template<typename WORDSYMBOLTABLE, typename UNITSYMBOLTABLE>
void parseentry (vector<char*> & lines, size_t & line, const set<wstring> & restricttokeys,
const WORDSYMBOLTABLE * wordmap, const UNITSYMBOLTABLE * unitmap, vector<typename WORDSEQUENCE::word> & wordseqbuffer, vector<typename WORDSEQUENCE::aligninfo> & alignseqbuffer,
const double htkTimeToFrame)
{
assert (line < lines.size());
string filename = lines[line++];
while (filename == "#!MLF!#") // skip embedded duplicate MLF headers (so user can 'cat' MLFs)
filename = lines[line++];
// some mlf file have write errors, so skip malformed entry
if (filename.length() < 3 || filename[0] != '"' || filename[filename.length()-1] != '"')
{
fprintf (stderr, "warning: filename entry (%s)\n", filename.c_str());
size_t s = line;
line = getnextmlfstart (lines, s);
fprintf (stderr, "skip current mlf entry form line (%lu) until line (%lu).\n", s, line);
return;
}
//fprintf (stderr,"start parse %s\n", filename.c_str());
filename = filename.substr (1, filename.length() -2); // strip quotes
if (filename.find ("*/") == 0) filename = filename.substr (2);
#ifdef _WIN32
wstring key = msra::strfun::utf16 (regex_replace (filename, regex ("\\.[^\\.\\\\/:]*$", std::regex_constants::extended), string())); // delete extension (or not if none)
#endif
#ifdef __unix__
wstring key = msra::strfun::utf16(removeExtension(basename(filename))); // note that c++ 4.8 is incomplete for supporting regex
#endif
//fwprintf (stderr,L"after parse %S\n",key.c_str());
// determine lines range
size_t s = line;
line = getnextmlfstart (lines, line);
size_t e = line - 1;
// lines range: [s,e)
// don't parse unused entries (this is supposed to be used for very small debugging setups with huge MLFs)
if (!restricttokeys.empty() && restricttokeys.find (key) == restricttokeys.end())
return;
vector<ENTRY> & entries = (*this)[key]; // this creates a new entry
if (!entries.empty()) malformed (msra::strfun::strprintf ("duplicate entry '%S'", key.c_str()));
entries.resize (e-s);
wordseqbuffer.resize (0);
alignseqbuffer.resize (0);
vector<char*> toks;
for (size_t i = s; i < e; i++)
{
strtok (lines[i], " \t", toks);
if (statelistmap.size() == 0)
entries[i-s].parse (toks, htkTimeToFrame);
else
entries[i-s].parsewithstatelist (toks, statelistmap, htkTimeToFrame);
// if we also read word entries, do it here
if (wordmap)
{
if (toks.size() > 6/*word entry are in this column*/)
{
const char * w = toks[6]; // the word name
int wid = (*wordmap)[w]; // map to word id --may be -1 for unseen words in the transcript (word list typically comes from a test LM)
size_t wordindex = (wid == -1) ? WORDSEQUENCE::word::unknownwordindex : (size_t) wid;
wordseqbuffer.push_back (typename WORDSEQUENCE::word (wordindex, entries[i-s].firstframe, alignseqbuffer.size()));
}
if (unitmap)
{
if (toks.size() > 4)
{
const char * u = toks[4]; // the triphone name
auto iter = unitmap->find (u); // map to unit id
if (iter == unitmap->end())
throw std::runtime_error (string ("parseentry: unknown unit ") + u + " in utterance " + strfun::utf8 (key));
const size_t uid = iter->second;
alignseqbuffer.push_back (typename WORDSEQUENCE::aligninfo (uid, 0/*#frames--we accumulate*/));
}
if (alignseqbuffer.empty())
throw std::runtime_error ("parseentry: lonely senone entry at start without phone/word entry found, for utterance " + strfun::utf8 (key));
alignseqbuffer.back().frames += entries[i-s].numframes; // (we do not have an overflow check here, but should...)
}
}
}
if (wordmap) // if reading word sequences as well (for MMI), then record it (in a separate map)
{
if (!entries.empty() && wordseqbuffer.empty())
throw std::runtime_error ("parseentry: got state alignment but no word-level info, although being requested, for utterance " + strfun::utf8 (key));
// post-process silence
// - first !silence -> !sent_start
// - last !silence -> !sent_end
int silence = (*wordmap)["!silence"];
if (silence >= 0)
{
int sentstart = (*wordmap)["!sent_start"]; // these must have been created
int sentend = (*wordmap)["!sent_end"];
// map first and last !silence to !sent_start and !sent_end, respectively
if (sentstart >= 0 && wordseqbuffer.front().wordindex == (size_t) silence)
wordseqbuffer.front().wordindex = sentstart;
if (sentend >= 0 && wordseqbuffer.back().wordindex == (size_t) silence)
wordseqbuffer.back().wordindex = sentend;
}
//if (sentstart < 0 || sentend < 0 || silence < 0)
// throw std::logic_error ("parseentry: word map must contain !silence, !sent_start, and !sent_end");
// implant
auto & wordsequence = wordsequences[key]; // this creates the map entry
wordsequence.words = wordseqbuffer; // makes a copy
wordsequence.align = alignseqbuffer;
}
}
public:
// return if input statename is sil state (hard code to compared first 3 chars with "sil")
bool issilstate (const string & statename) const // (later use some configuration table)
{
return (statename.size() > 3 && statename.at(0) == 's' && statename.at(1) == 'i' && statename.at(2) == 'l');
}
vector<bool> issilstatetable; // [state index] => true if is sil state (cached)
// return if input stateid represent sil state (by table lookup)
bool issilstate (const size_t id) const
{
assert (id < issilstatetable.size());
return issilstatetable[id];
}
struct nullmap { int operator[] (const char * s) const { throw std::logic_error ("nullmap: should never be used"); } }; // to satisfy a template, never used... :(
// constructor reads multiple MLF files
htkmlfreader (const vector<wstring> & paths, const set<wstring> & restricttokeys, const wstring & stateListPath = L"", const double htkTimeToFrame = 100000.0)
{
// read state list
if (stateListPath != L"")
readstatelist (stateListPath);
// read MLF(s) --note: there can be multiple, so this is a loop
foreach_index (i, paths)
read (paths[i], restricttokeys, (nullmap* /*to satisfy C++ template resolution*/) NULL, (map<string,size_t>*) NULL, htkTimeToFrame);
}
// alternate constructor that optionally also reads word alignments (for MMI training); triggered by providing a 'wordmap'
// (We cannot use an optional arg in the constructor aboe because it interferes with teh template resolution.)
template<typename WORDSYMBOLTABLE, typename UNITSYMBOLTABLE>
htkmlfreader (const vector<wstring> & paths, const set<wstring> & restricttokeys, const wstring & stateListPath, const WORDSYMBOLTABLE * wordmap, const UNITSYMBOLTABLE * unitmap, const double htkTimeToFrame)
{
// read state list
if (stateListPath != L"")
readstatelist (stateListPath);
// read MLF(s) --note: there can be multiple, so this is a loop
foreach_index (i, paths)
read (paths[i], restricttokeys, wordmap, unitmap, htkTimeToFrame);
}
// note: this function is not designed to be pretty but to be fast
template<typename WORDSYMBOLTABLE, typename UNITSYMBOLTABLE>
void read (const wstring & path, const set<wstring> & restricttokeys, const WORDSYMBOLTABLE * wordmap, const UNITSYMBOLTABLE * unitmap, const double htkTimeToFrame)
{
if (!restricttokeys.empty() && this->size() >= restricttokeys.size()) // no need to even read the file if we are there (we support multiple files)
return;
fprintf (stderr, "htkmlfreader: reading MLF file %S ...", path.c_str());
curpath = path; // for error messages only
vector<char> buffer; // buffer owns the characters--don't release until done
vector<char*> lines = readlines (path, buffer);
vector<typename WORDSEQUENCE::word> wordsequencebuffer;
vector<typename WORDSEQUENCE::aligninfo> alignsequencebuffer;
if (lines.empty() || strcmp (lines[0], "#!MLF!#")) malformed ("header missing");
// parse entries
fprintf (stderr, "parse the line %zu\n", lines.size());
size_t line = 1;
while (line < lines.size() && (restricttokeys.empty() || this->size() < restricttokeys.size()))
parseentry (lines, line, restricttokeys, wordmap, unitmap, wordsequencebuffer, alignsequencebuffer, htkTimeToFrame);
curpath.clear();
fprintf (stderr, " total %lu entries\n", this->size());
}
// read state list, index is from 0
void readstatelist (const wstring & stateListPath = L"")
{
if (stateListPath != L"")
{
vector<char> buffer; // buffer owns the characters--don't release until done
vector<char*> lines = readlines (stateListPath, buffer);
size_t index;
issilstatetable.reserve (lines.size());
for (index = 0; index < lines.size(); index++)
{
statelistmap[lines[index]] = index;
issilstatetable.push_back (issilstate (lines[index]));
}
if (index != statelistmap.size())
throw std::runtime_error (msra::strfun::strprintf ("readstatelist: lines (%d) not equal to statelistmap size (%d)", index, statelistmap.size()));
if (statelistmap.size() != issilstatetable.size())
throw std::runtime_error (msra::strfun::strprintf ("readstatelist: size of statelookuparray (%d) not equal to statelistmap size (%d)", issilstatetable.size(), statelistmap.size()));
fprintf (stderr, "total %lu state names in state list %S\n", statelistmap.size(), stateListPath.c_str());
}
}
// return state num: varify the fintune layer dim
size_t getstatenum () const
{
return statelistmap.size();
}
size_t getstateid (string statename) // added by Hang Su adaptation
{
return statelistmap[statename];
}
// access to word sequences
const map<wstring,WORDSEQUENCE> & allwordtranscripts() const { return wordsequences; }
};
};}; // namespaces

Просмотреть файл

@ -1,743 +0,0 @@
//
// <copyright file="latticearchive.cpp" company="Microsoft">
// Copyright (c) Microsoft Corporation. All rights reserved.
// </copyright>
//
#pragma once
#include "stdafx.h"
#include "basetypes.h"
#include "fileutil.h"
#include "htkfeatio.h" // for MLF reading for numer lattices
#include "latticearchive.h"
#include "msra_mgram.h" // for MLF reading for numer lattices
#include <stdio.h>
#include <stdint.h>
#include <vector>
#include <string>
#include <set>
#include <hash_map>
#include <regex>
#pragma warning(disable : 4996)
namespace msra { namespace lattices {
// helper to write a symbol hash (string -> int) to a file
// File has two sections:
// - physicalunitname // line number is mapping, starting with 0
// - logunitname physicalunitname // establishes a mapping; logunitname will get the same numeric index as physicalunitname
template<class UNITMAP>
static void writeunitmap (const wstring & symlistpath, const UNITMAP & unitmap)
{
std::vector<std::string> units;
units.reserve (unitmap.size());
std::vector<std::string> mappings;
mappings.reserve (unitmap.size());
for (auto iter = unitmap.cbegin(); iter != unitmap.cend(); iter++) // why would 'for (auto iter : unitmap)' not work?
{
const std::string label = iter->first;
const size_t unitid = iter->second;
if (units.size() <= unitid)
units.resize (unitid + 1); // we grow it on demand; the result must be compact (all entries filled), we check that later
if (!units[unitid].empty()) // many-to-one mapping: remember the unit; look it up while writing
mappings.push_back (label);
else
units[unitid] = label;
}
auto_file_ptr flist = fopenOrDie (symlistpath, L"wb");
// write (physical) units
foreach_index (k, units)
{
if (units[k].empty())
throw std::logic_error ("build: unitmap has gaps");
fprintfOrDie (flist, "%s\n", units[k].c_str());
}
// write log-phys mappings
foreach_index (k, mappings)
{
const std::string unit = mappings[k]; // logical name
const size_t unitid = unitmap.find (unit)->second; // get its unit id; this indexes the units array
const std::string tounit = units[unitid]; // and get the name from tehre
fprintfOrDie (flist, "%s %s\n", unit.c_str(), tounit.c_str());
}
fflushOrDie (flist);
}
// (little helper to do a map::find() with default value)
template<typename MAPTYPE, typename KEYTYPE, typename VALTYPE>
static size_t tryfind (const MAPTYPE & map, const KEYTYPE & key, VALTYPE deflt)
{
auto iter = map.find (key);
if (iter == map.end())
return deflt;
else
return iter->second;
}
// archive format:
// - output files of build():
// - OUTPATH --the resulting archive (a huge file), simple concatenation of binary blocks
// - OUTPATH.toc --contains keys and offsets; this is how content in archive is found
// KEY=ARCHIVE[BYTEOFFSET] // where ARCHIVE can be empty, meaning same as previous
// - OUTPATH.symlist --list of all unit names encountered, in order of numeric index used in archive (first = index 0)
// This file is suitable as an input to HHEd's AU command.
// - in actual use,
// - .toc files can be concatenated
// - .symlist files must remain paired with the archive file
// - for actual training, user also needs to provide, typically from an HHEd AU run:
// - OUTPATH.tying --map from triphone units to senone sequence by name; get full phone set from .symlist above
// UNITNAME SENONE[2] SENONE[3] SENONE[4]
/*static*/ void archive::build (const std::vector<std::wstring> & infiles, const std::wstring & outpath,
const std::unordered_map<std::string,size_t> & modelsymmap,
const msra::asr::htkmlfreader<msra::asr::htkmlfentry,msra::lattices::lattice::htkmlfwordsequence> & labels, // non-empty: build numer lattices
const msra::lm::CMGramLM & unigram, const msra::lm::CSymbolSet & unigramsymbols) // for numer lattices
{
#if 0 // little unit test helper for testing the read function
bool test = true;
if (test)
{
archive a;
a.open (outpath + L".toc");
lattice L;
std::hash_map<string,size_t> symmap;
a.getlattice (L"sw2001_A_1263622500_1374610000", L, symmap);
a.getlattice (L"sw2001_A_1391162500_1409287500", L, symmap);
return;
}
#endif
const bool numermode = !labels.empty(); // if labels are passed then we shall convert the MLFs to lattices, and 'infiles' are regular keys
const std::wstring tocpath = outpath + L".toc";
const std::wstring symlistpath = outpath + L".symlist";
// process all files
std::set<std::wstring> seenkeys; // (keep track of seen keys; throw error for duplicate keys)
msra::files::make_intermediate_dirs (outpath);
auto_file_ptr f = fopenOrDie (outpath, L"wb");
auto_file_ptr ftoc = fopenOrDie (tocpath, L"wb");
size_t brokeninputfiles = 0;
foreach_index (i, infiles)
{
const std::wstring & inlatpath = infiles[i];
fprintf (stderr, "build: processing lattice '%S'\n", inlatpath.c_str());
// get key
std::wstring key = regex_replace (inlatpath, wregex (L"=.*"), wstring()); // delete mapping
key = regex_replace (key, wregex (L".*[\\\\/]"), wstring()); // delete path
key = regex_replace (key, wregex (L"\\.[^\\.\\\\/:]*$"), wstring()); // delete extension (or not if none)
if (!seenkeys.insert (key).second)
throw std::runtime_error (msra::strfun::strprintf ("build: duplicate key for lattice '%S'", inlatpath.c_str()));
// we fail all the time due to totally broken HDecode/copy process, OK if not too many files are missing
bool latticeread = false;
try
{
// fetch lattice
lattice L;
if (!numermode)
L.fromhtklattice (inlatpath, modelsymmap); // read HTK lattice
else
L.frommlf (key, modelsymmap, labels, unigram, unigramsymbols); // read MLF into a numerator lattice
latticeread = true;
// write to archive
uint64_t offset = fgetpos (f);
L.fwrite (f);
fflushOrDie (f);
// write reference to TOC file --note: TOC file is a headerless UTF8 file; so don't use fprintf %S format (default code page)
fprintfOrDie (ftoc, "%s=%s[%llu]\n", msra::strfun::utf8 (key).c_str(), ((i - brokeninputfiles) == 0) ? msra::strfun::utf8 (outpath).c_str() : "", offset);
fflushOrDie (ftoc);
fprintf (stderr, "written lattice to offset %llu as '%S'\n", offset, key.c_str());
}
catch (const exception & e)
{
if (latticeread) throw; // write failure
// we ignore read failures
fprintf (stderr, "ERROR: skipping unreadable lattice '%S': %s\n", inlatpath.c_str(), e.what());
brokeninputfiles++;
}
}
// write out the unit map
// TODO: This is sort of redundant now--it gets the symmap from the HMM, i.e. always the same for all archives.
writeunitmap (symlistpath, modelsymmap);
fprintf (stderr, "completed %lu out of %lu lattices (%lu read failures, %.1f%%)\n", infiles.size(), infiles.size()-brokeninputfiles, brokeninputfiles, 100.0f * brokeninputfiles / infiles.size());
}
// helper to set a context value (left, right) with checking of uniqueness
void lattice::nodecontext::setcontext (int & lr, int val)
{
if (lr == unknown)
lr = val;
else if (lr != val)
lr = (signed short) ambiguous;
}
// helper for merge() to determine the unique node contexts
vector<lattice::nodecontext> lattice::determinenodecontexts (const msra::asr::simplesenonehmm & hset) const
{
const size_t spunit = tryfind (hset.getsymmap(), "sp", SIZE_MAX);
const size_t silunit = tryfind (hset.getsymmap(), "sil", SIZE_MAX);
vector<lattice::nodecontext> nodecontexts (nodes.size());
nodecontexts.front().left = nodecontext::start;
nodecontexts.front().right = nodecontext::ambiguous; // (should not happen, but won't harm either)
nodecontexts.back().right = nodecontext::end;
nodecontexts.back().left = nodecontext::ambiguous; // (should not happen--we require !sent_end; but who knows)
size_t multispseen = 0; // bad entries with multi-sp
foreach_index (j, edges)
{
const auto & e = edges[j];
const size_t S = e.S;
const size_t E = e.E;
auto a = getaligninfo (j);
if (a.size() == 0) // !NULL edge
throw std::logic_error ("determinenodecontexts: !NULL edges not allowed in merging, should be removed before");
size_t A = a[0].unit;
size_t Z = a[a.size()-1].unit;
if (Z == spunit)
{
if (a.size() < 2)
throw std::runtime_error ("determinenodecontexts: context-free unit (/sp/) found as a single-phone word");
else
{
Z = a[a.size()-2].unit;
if (Z == spunit) // a bugg lattice --I got this from HVite, to be tracked down
{
// search from end once again, to print a warning
int n;
for (n = (int) a.size() -1; n >= 0; n--)
if (a[n].unit != spunit)
break;
// ends with n = position of furthest non-sp
if (n < 0) // only sp?
throw std::runtime_error ("determinenodecontexts: word consists only of /sp/");
fprintf (stderr, "determinenodecontexts: word with %lu /sp/ at the end found, edge %d\n", a.size() -1 - n, j);
multispseen++;
Z = a[n].unit;
}
}
}
if (A == spunit || Z == spunit)
{
#if 0
fprintf (stderr, "A=%d Z=%d fa=%d j=%d/N=%d L=%d n=%d totalalign=%d ts/te=%d/%d\n", (int) A, (int) Z, (int) e.firstalign,(int) j, (int) edges.size(), (int) nodes.size(), (int) a.size(), (int) align.size(),
nodes[S].t, nodes[E].t);
foreach_index (kk, a)
fprintf (stderr, "a[%d] = %d\n", kk, a[kk].unit);
dump (stderr, [&] (size_t i) { return hset.gethmm (i).getname(); });
#endif
throw std::runtime_error ("determinenodecontexts: context-free unit (/sp/) found as a start phone or second last phone");
}
const auto & Ahmm = hset.gethmm (A);
const auto & Zhmm = hset.gethmm (Z);
int Aid = (int) Ahmm.gettransPindex();
int Zid = (int) Zhmm.gettransPindex();
nodecontexts[S].setright (Aid);
nodecontexts[E].setleft (Zid);
}
if (multispseen > 0)
fprintf (stderr, "determinenodecontexts: %lu broken edges in %lu with multiple /sp/ at the end seen\n", multispseen, edges.size());
// check CI conditions and put in 't'
// We make the hard assumption that there is only one CI phone, /sil/.
const auto & silhmm = hset.gethmm (silunit);
int silid = silhmm.gettransPindex();
foreach_index (i, nodecontexts)
{
auto & nc = nodecontexts[i];
if ((nc.left == nodecontext::unknown) ^ (nc.right == nodecontext::unknown))
throw std::runtime_error ("determinenodecontexts: invalid dead-end node in lattice");
if (nc.left == nodecontext::ambiguous && nc.right != silid && nc.right != nodecontext::end)
throw std::runtime_error ("determinenodecontexts: invalid ambiguous left context (right context is not CI)");
if (nc.right == nodecontext::ambiguous && nc.left != silid && nc.left != nodecontext::start)
throw std::runtime_error ("determinenodecontexts: invalid ambiguous right context (left context is not CI)");
nc.t = nodes[i].t;
}
return nodecontexts; // (will this use a move constructor??)
}
// compar function for sorting and merging
bool lattice::nodecontext::operator< (const nodecontext & other) const
{
// sort by t, left, right, i --sort by i to make i appear before iother, as assumed in merge function
int diff = (int) t - (int) other.t;
if (diff == 0)
{
diff = left - other.left;
if (diff == 0)
{
diff = right - other.right;
if (diff == 0)
return i < other.i; // (cannot use 'diff=' pattern since unsigned but may be SIZE_MAX)
}
}
return diff < 0;
}
// remove that final !NULL edge
// We have that in HAPI lattices, but there can be only one at the end.
void lattice::removefinalnull()
{
const auto & lastedge = edges.back();
// last edge can be !NULL, recognized as having 0 alignment records
if (lastedge.firstalign < align.size()) // has alignment records --not !NULL
return;
if (lastedge.S != nodes.size() -2 || lastedge.E != nodes.size() -1)
throw std::runtime_error ("removefinalnull: malformed final !NULL edge");
edges.resize (edges.size() -1); // remove it
nodes.resize (nodes.size() -1); // its start node is now the new end node
foreach_index (j, edges)
if (edges[j].E >= nodes.size())
throw std::runtime_error ("removefinalnull: cannot have final !NULL edge and other edges connecting to end node at the same time");
}
// merge a secondary lattice into the first
// With lots of caveats:
// - this optimizes lattices to true unigram lattices where the only unique node condition is acoustic context
// - no !NULL edge at the end, call removefinalnull() before
// - this function returns an unsorted edges[] array, i.e. invalid. We sort in uniq'ed representation, which is easier.
// This function is not elegant at all, just hard labor!
void lattice::merge (const lattice & other, const msra::asr::simplesenonehmm & hset)
{
if (!edges2.empty() || !other.edges2.empty())
throw std::logic_error ("merge: lattice(s) must be in non-uniq'ed format (V1)");
if (!info.numframes || !other.info.numframes)
throw std::logic_error ("merge: lattice(s) must have identical number of frames");
// establish node contexts
auto contexts = determinenodecontexts (hset);
auto othercontexts = other.determinenodecontexts (hset);
// create joint node space and node mapping
// This also collapses non-unique nodes.
// Note the edge case sil-sil in one lattice which may be sil-ambiguous or ambiguous-sil on the other.
// We ignore this, keeping such nodes unmerged. That's OK since middle /sil/ words have zero LM, and thus it's OK to keep them non-connected.
foreach_index (i, contexts) contexts[i].i = i;
foreach_index (i, othercontexts) othercontexts[i].iother = i;
contexts.insert (contexts.end(), othercontexts.begin(), othercontexts.end()); // append othercontext
sort (contexts.begin(), contexts.end());
vector<size_t> nodemap (nodes.size(), SIZE_MAX);
vector<size_t> othernodemap (other.nodes.size(), SIZE_MAX);
int j = 0;
foreach_index (i, contexts) // merge identical nodes --this is the critical step
{
if (j == 0 || contexts[j-1].t != contexts[i].t || contexts[j-1].left != contexts[i].left || contexts[j-1].right != contexts[i].right)
contexts[j++] = contexts[i]; // entered a new one
// node map
if (contexts[i].i != SIZE_MAX)
nodemap[contexts[i].i] = j-1;
if (contexts[i].iother != SIZE_MAX)
othernodemap[contexts[i].iother] = j-1;
}
fprintf (stderr, "merge: joint node space uniq'ed to %d from %d\n", j, contexts.size());
contexts.resize (j);
// create a new node array (just copy the contexts[].t fields)
nodes.resize (contexts.size());
foreach_index (inew, nodes)
nodes[inew].t = (unsigned short) contexts[inew].t;
info.numnodes = nodes.size();
// incorporate the alignment records
const size_t alignoffset = align.size();
align.insert (align.end(), other.align.begin(), other.align.end());
// map existing edges' S and E fields, and also 'firstalign'
foreach_index (j, edges)
{
edges[j].S = nodemap[edges[j].S];
edges[j].E = nodemap[edges[j].E];
}
auto otheredges = other.edges;
foreach_index (j, otheredges)
{
otheredges[j].S = othernodemap[otheredges[j].S];
otheredges[j].E = othernodemap[otheredges[j].E];
otheredges[j].firstalign += alignoffset; // that's where they are now
}
// at this point, a new 'nodes' array exists, and the edges already are w.r.t. the new node space and align space
// now we are read to merge 'other' edges into this, simply by concatenation
edges.insert (edges.end(), otheredges.begin(), otheredges.end());
// remove acoustic scores --they are likely not identical if they come from different decoders
// If we don't do that, this will break the sorting in builduniquealignments()
info.hasacscores = 0;
foreach_index (j, edges)
edges[j].a = 0.0f;
// Note: we have NOT sorted or de-duplicated yet. That is best done after conversion to the uniq'ed format.
}
// remove duplicates
// This must be called in uniq'ed format.
void lattice::dedup()
{
if (edges2.empty())
throw std::logic_error ("dedup: lattice must be in uniq'ed format (V2)");
size_t k = 0;
foreach_index (j, edges2)
{
if (k > 0 && edges2[k-1].S == edges2[j].S && edges2[k-1].E == edges2[j].E && edges2[k-1].firstalign == edges2[j].firstalign)
{
if (edges2[k-1].implysp != edges2[j].implysp)
throw std::logic_error ("dedup: inconsistent 'implysp' flag for otherwise identical edges");
continue;
}
edges2[k++] = edges2[j];
}
fprintf (stderr, "dedup: edges reduced to %d from %d\n", k, edges2.size());
edges2.resize (k);
info.numedges = edges2.size();
edges.clear(); // (should already be, but isn't; make sure we no longer use it)
}
// load all lattices from a TOC file and write them to a new archive
// Use this to
// - upgrade the file format to latest in case of format changes
// - check consistency (read only; don't write out)
// - dump to stdout
// - merge two lattices (for merging numer into denom lattices)
// Input path is an actual TOC path, output is the stem (.TOC will be added). --yes, not nice, maybe fix it later
// Example command:
// convertlatticearchive --latticetocs dummy c:\smbrdebug\sw20_small.den.lats.toc.10 -w c:\smbrdebug\sw20_small.den.lats.converted --cdphonetying c:\smbrdebug\combined.tying --statelist c:\smbrdebug\swb300h.9304.aligned.statelist --transprobs c:\smbrdebug\MMF.9304.transprobs
// How to regenerate from my test lattices:
// buildlatticearchive c:\smbrdebug\sw20_small.den.lats.regenerated c:\smbrdebug\hvitelat\*lat
// We support two special output path syntaxs:
// - empty ("") -> don't output, just check the format
// - dash ("-") -> dump lattice to stdout instead
/*static*/ void archive::convert (const std::wstring & intocpath, const std::wstring & intocpath2, const std::wstring & outpath,
const msra::asr::simplesenonehmm & hset)
{
const auto & modelsymmap = hset.getsymmap();
const std::wstring tocpath = outpath + L".toc";
const std::wstring symlistpath = outpath + L".symlist";
// open input archive
// TODO: I find that HVite emits redundant physical triphones, and even HHEd seems so (in .tying file).
// Thus, we should uniq the units before sorting. We can do that here if we have the .tying file.
// And then use the modelsymmap to map them down.
// Do this directly in the hset module (it will be transparent).
std::vector<std::wstring> intocpaths (1, intocpath); // set of paths consisting of 1
msra::lattices::archive archive (intocpaths, modelsymmap);
// secondary archive for optional merging operation
const bool mergemode = !intocpath2.empty(); // true if merging two lattices
std::vector<std::wstring> intocpaths2;
if (mergemode)
intocpaths2.push_back (intocpath2);
msra::lattices::archive archive2 (intocpaths2, modelsymmap); // (if no merging then this archive2 is empty)
// read the intocpath file once again to get the keys in original order
std::vector<char> textbuffer;
auto toclines = msra::files::fgetfilelines (intocpath, textbuffer);
auto_file_ptr f = NULL;
auto_file_ptr ftoc = NULL;
// process all files
if (outpath != L"" && outpath != L"-") // test for special syntaxes that bypass to actually create an output archive
{
msra::files::make_intermediate_dirs (outpath);
f = fopenOrDie (outpath, L"wb");
ftoc = fopenOrDie (tocpath, L"wb");
}
vector<const char *> invmodelsymmap; // only used for dump() mode
// we must parse the toc file once again to get the keys in original order
size_t skippedmerges = 0;
foreach_index (i, toclines)
{
const char * line = toclines[i];
const char * p = strchr (line, '=');
if (p == NULL)
throw std::runtime_error ("open: invalid TOC line (no = sign): " + std::string (line));
const std::wstring key = msra::strfun::utf16 (std::string (line, p - line));
fprintf (stderr, "convert: processing lattice '%S'\n", key.c_str());
// fetch lattice --this performs any necessary format conversions already
lattice L;
archive.getlattice (key, L);
lattice L2;
if (mergemode)
{
if (!archive2.haslattice (key))
{
fprintf (stderr, "convert: cannot merge because lattice '%S' missing in secondary archive; skipping\n", key.c_str());
skippedmerges++;
continue;
}
archive2.getlattice (key, L2);
// merge it in
// This will connect each node with matching 1-phone context conditions; aimed at merging numer lattices.
L.removefinalnull(); // get rid of that final !NULL headache
L2.removefinalnull();
L.merge (L2, hset);
// note: we are left with dups due to true unigram merging (HTK lattices cannot represent true unigram lattices since id is on the nodes)
}
//L.removefinalnull();
//L.determinenodecontexts (hset);
// convert it --TODO: once we permanently use the new format, do this in fread() for V1
// Note: Merging may have left this in unsorted format; we need to be robust against that.
const size_t spunit = tryfind (modelsymmap, "sp", SIZE_MAX);
L.builduniquealignments (spunit);
if (mergemode)
L.dedup();
if (f && ftoc)
{
// write to archive
uint64_t offset = fgetpos (f);
L.fwrite (f);
fflushOrDie (f);
// write reference to TOC file --note: TOC file is a headerless UTF8 file; so don't use fprintf %S format (default code page)
fprintfOrDie (ftoc, "%s=%s[%llu]\n", msra::strfun::utf8 (key).c_str(), (i == 0) ? msra::strfun::utf8 (outpath).c_str() : "", offset);
fflushOrDie (ftoc);
fprintf (stderr, "written converted lattice to offset %llu as '%S'\n", offset, key.c_str());
}
else if (outpath == L"-")
{
if (invmodelsymmap.empty()) // build this lazily
{
invmodelsymmap.resize (modelsymmap.size());
for (auto iter = modelsymmap.begin(); iter != modelsymmap.end(); iter++)
invmodelsymmap[iter->second] = iter->first.c_str();
}
L.rebuildedges (false);
L.dump (stdout, [&] (size_t i) { return invmodelsymmap[i]; } );
}
} // end for (toclines)
if (skippedmerges > 0)
fprintf (stderr, "convert: %d out of %d merge operations skipped due to secondary lattice missing\n", skippedmerges, toclines.size());
// write out the updated unit map
if (f && ftoc)
writeunitmap (symlistpath, modelsymmap);
fprintf (stderr, "converted %d lattices\n", toclines.size());
}
// ---------------------------------------------------------------------------
// reading lattices from external formats (HTK lat, MLF)
// ---------------------------------------------------------------------------
// read an HTK lattice
// The lattice is expected to be freshly constructed (I did not bother to check).
void lattice::fromhtklattice (const wstring & path, const std::unordered_map<std::string,size_t> & unitmap)
{
vector<char> textbuffer;
auto lines = msra::files::fgetfilelines (path, textbuffer);
if (lines.empty())
throw std::runtime_error ("lattice: mal-formed lattice--empty input file (or all-zeroes)");
auto iter = lines.begin();
// parse out LMF and WP
char dummychar = 0; // dummy for sscanf() end checking
for ( ; iter != lines.end() && strncmp (*iter, "N=", 2); iter++)
{
if (strncmp (*iter, "lmscale=", 8) == 0) // note: HTK sometimes generates extra garbage space at the end of this line
if (sscanf_s (*iter, "lmscale=%f wdpenalty=%f%c", &info.lmf, &info.wp, &dummychar, sizeof (dummychar)) != 2 && dummychar != ' ')
throw std::runtime_error ("lattice: mal-formed lmscale/wdpenalty line in lattice: " + string (*iter));
}
// parse N and L
if (iter != lines.end())
{
unsigned long N, L;
if (sscanf_s (*iter, "N=%lu L=%lu %c", &N, &L, &dummychar, sizeof (dummychar)) != 2)
throw std::runtime_error ("lattice: mal-formed N=/L= line in lattice: " + string (*iter));
info.numnodes = N;
info.numedges = L;
iter++;
}
else
throw std::runtime_error ("lattice: mal-formed before parse N=/L= line in lattice.");
ASSERT(info.numnodes > 0);
nodes.reserve (info.numnodes);
// parse the nodes
for (size_t i = 0; i < info.numnodes; i++, iter++)
{
if (iter == lines.end())
throw std::runtime_error ("lattice: not enough I lines in lattice");
unsigned long itest;
float t;
if (sscanf_s (*iter, "I=%lu t=%f%c", &itest, &t, &dummychar, sizeof (dummychar)) < 2)
throw std::runtime_error ("lattice: mal-formed node line in lattice: " + string (*iter));
if (i != (size_t) itest)
throw std::runtime_error ("lattice: out-of-sequence node line in lattice: " + string (*iter));
nodes.push_back (nodeinfo ((unsigned int) (t / info.frameduration + 0.5)));
info.numframes = max (info.numframes, (size_t) nodes.back().t);
}
// parse the edges
ASSERT(info.numedges > 0);
edges.reserve (info.numedges);
align.reserve (info.numedges * 10); // 10 phones per word on av. should be enough
std::string label;
for (size_t j = 0; j < info.numedges; j++, iter++)
{
if (iter == lines.end())
throw std::runtime_error ("lattice: not enough J lines in lattice");
unsigned long jtest;
unsigned long S, E;
float a, l;
char d[1024];
// example:
// J=12 S=1 E=13 a=-326.81 l=-5.090 d=:sil-t:s+k:e,0.03:dh:m-ax:m+sil,0.03:sil,0.02:
int nvals = sscanf_s (*iter, "J=%lu S=%lu E=%lu a=%f l=%f d=%s", &jtest, &S, &E, &a, &l, &d, sizeof (d));
if (nvals == 5 && j == info.numedges - 1) // special case: last edge is a !NULL and thus may have the d= record missing
strcpy (d, ":");
else if (nvals != 6)
throw std::runtime_error ("lattice: mal-formed edge line in lattice: " + string (*iter));
if (j != (size_t) jtest)
throw std::runtime_error ("lattice: out-of-sequence edge line in lattice: " + string (*iter));
edges.push_back (edgeinfowithscores (S, E, a, l, align.size()));
// build align array
size_t edgeframes = 0; // (for checking whether the alignment sums up right)
const char * p = d;
if (p[0] != ':' || (p[1] == 0 && j < info.numedges-1)) // last edge may be empty
throw std::runtime_error ("lattice: alignment info must start with a colon and must have at least one entry: " + string (*iter));
p++;
while (*p)
{
// p points to an entry of the form TRIPHONE,DURATION
const char * q = strchr (p, ',');
if (q == NULL)
throw std::runtime_error ("lattice: alignment entry lacking a comma: " + string (*iter));
if (q == p)
throw std::runtime_error ("lattice: alignment entry label empty: " + string (*iter));
label.assign (p, q-p); // the triphone label
q++;
char * ep;
double duration = strtod (q, &ep); // (weird--returns a non-const ptr in ep to a const object)
p = ep;
if (*p != ':')
throw std::runtime_error ("lattice: alignment entry not ending with a colon: " + string (*iter));
p++;
// create the alignment entry
const size_t frames = (unsigned int) (duration / info.frameduration + 0.5);
auto it = unitmap.find (label);
if (it == unitmap.end())
throw std::runtime_error ("lattice: unit in alignment that is not in model: " + label);
const size_t unitid = it->second;
//const size_t unitid = unitmap.insert (make_pair (label, unitmap.size())).first->second; // may create a new entry with index = #entries
align.push_back (aligninfo (unitid, frames));
edgeframes += frames;
}
if (edgeframes != nodes[E].t - (size_t) nodes[S].t)
{
char msg[128];
sprintf (msg, "\n-- where edgeframes=%d != (nodes[E].t - nodes[S].t=%d), the gap is %d.", edgeframes, nodes[E].t - (size_t) nodes[S].t, edgeframes + nodes[S].t - nodes[E].t);
throw std::runtime_error ("lattice: alignment info duration mismatches edge duration: " + string (*iter) + msg);
}
}
if (iter != lines.end())
throw std::runtime_error ("lattice: unexpected garbage at end of lattice: " + string (*iter));
checklattice();
// create more efficient storage for alignments
const size_t spunit = tryfind (unitmap, "sp", SIZE_MAX);
builduniquealignments (spunit);
showstats();
}
// construct a numerator lattice from an MLF entry
// The lattice is expected to be freshly constructed (I did not bother to check).
void lattice::frommlf (const wstring & key, const std::unordered_map<std::string,size_t> & unitmap,
const msra::asr::htkmlfreader<msra::asr::htkmlfentry,lattice::htkmlfwordsequence> & labels,
const msra::lm::CMGramLM & unigram, const msra::lm::CSymbolSet & unigramsymbols)
{
const auto & transcripts = labels.allwordtranscripts(); // (TODO: we could just pass the transcripts map--does not really matter)
// get the labels (state and word)
auto iter = transcripts.find (key);
if (iter == transcripts.end())
throw std::runtime_error ("frommlf: no reference word sequence in MLF for lattice with key " + strfun::utf8 (key));
const auto & transcript = iter->second;
if (transcript.words.size() == 0)
throw std::runtime_error ("frommlf: empty reference word sequence for lattice with key " + strfun::utf8 (key));
// determine unigram scores for all words
vector<float> lmscores (transcript.words.size());
size_t silence = unigramsymbols["!silence"];
size_t lmend = unigramsymbols["</s>"];
size_t sentstart = unigramsymbols["!sent_start"];
size_t sentend = unigramsymbols["!sent_end"];
// create the lattice
nodes.resize (transcript.words.size() +1);
edges.resize (transcript.words.size());
align.reserve (transcript.align.size());
size_t numframes = 0;
foreach_index (j, transcript.words)
{
const auto & w = transcript.words[j];
nodes[j].t = w.firstframe;
auto & e = edges[j];
e.unused = 0;
e.S = j;
e.E = j+1;
if (e.E != j+1)
throw std::runtime_error (msra::strfun::strprintf ("frommlf: too many tokens to be represented as edgeinfo::E in label set: %S", key.c_str()));
e.a = 0.0f; // no ac score
// LM score
// !sent_start and !silence are patched to LM score 0
size_t wid = w.wordindex;
if (wid == sentstart)
{
if (j != 0)
throw std::logic_error ("frommlf: found an !sent_start token not at the first position");
}
else if (wid == sentend)
{
if (j != (int) transcript.words.size()-1)
throw std::logic_error ("frommlf: found an !sent_end token not at the end position");
wid = lmend; // use </s> for score lookup
}
const int iwid = (int) wid;
e.l = (wid != sentstart && wid != silence) ? (float) unigram.score (&iwid, 1) : 0.0f;
// alignment
e.implysp = 0;
e.firstalign = align.size();
auto a = transcript.getaligninfo (j);
align.insert (align.end(), a.begin(), a.end());
foreach_index (k, a)
numframes += a[k].frames;
}
nodes[transcript.words.size()].t = (unsigned short) numframes;
if (nodes[transcript.words.size()].t != numframes)
throw std::runtime_error (msra::strfun::strprintf ("frommlf: too many frames to be represented as nodeinfo::t in label set: %S", key.c_str()));
info.lmf = -1.0f; // indicates not set
info.wp = 0.0f; // not set indicated by lmf < 0
info.numedges = edges.size();
info.numnodes = nodes.size();
info.numframes = numframes;
checklattice();
// create more efficient storage for alignments
const size_t spunit = tryfind (unitmap, "sp", SIZE_MAX);
builduniquealignments (spunit);
showstats();
}
};};

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -1,119 +0,0 @@
//
// <copyright file="latticestorage.h" company="Microsoft">
// Copyright (c) Microsoft Corporation. All rights reserved.
// </copyright>
//
// latticestorage.h -- basic data structures for storing lattices
#if 0 // [v-hansu] separate code with history
#endif
#pragma once
#include <string> // for the error message in checkoverflow() only
#include <stdexcept>
#include <stdint.h>
#undef INITIAL_STRANGE // [v-hansu] intialize structs to strange values
#define PARALLEL_SIL // [v-hansu] process sil on CUDA, used in other files, please search this
#define LOGZERO -1e30f
namespace msra { namespace lattices {
static void checkoverflow (size_t fieldval, size_t targetval, const char * fieldname)
{
if (fieldval != targetval)
{
char buf[1000];
sprintf_s (buf, "lattice: bit field %s too small for value 0x%zu (cut from 0x%zu)", fieldname, targetval, fieldval);
throw std::runtime_error (buf);
}
}
struct nodeinfo
{
//uint64_t firstinedge : 24; // index of first incoming edge
//uint64_t firstoutedge : 24; // index of first outgoing edge
//uint64_t t : 16; // time associated with this
unsigned short t; // time associated with this
nodeinfo (size_t pt) : t ((unsigned short) pt) //, firstinedge (NOEDGE), firstoutedge (NOEDGE)
{
checkoverflow (t, pt, "nodeinfo::t");
//checkoverflow (firstinedge, NOEDGE, "nodeinfo::firstinedge");
//checkoverflow (firstoutedge, NOEDGE, "nodeinfo::firstoutedge");
}
nodeinfo() // [v-hansu] initialize to impossible values
{
#ifdef INITIAL_STRANGE
t = unsigned short (-1);
#endif
}
};
// V2 format: a and l are stored in separate vectors
struct edgeinfo
{
uint64_t S : 19; // start node
uint64_t unused : 1; // (for future use)
uint64_t E : 19; // end node
uint64_t implysp : 1; // 1--alignment ends with a /sp/ that is not stored
uint64_t firstalign : 24; // index into align for first entry; end is firstalign of next edge
edgeinfo (size_t pS, size_t pE, size_t pfirstalign) : S (pS), E (pE), firstalign (pfirstalign), unused (0), implysp (0)
{
checkoverflow (S, pS, "edgeinfowithscores::S");
checkoverflow (E, pE, "edgeinfowithscores::E");
checkoverflow (firstalign, pfirstalign, "edgeinfowithscores::firstalign");
}
edgeinfo() // [v-hansu] initialize to impossible values
{
#ifdef INITIAL_STRANGE
S = uint64_t (-1);
unused = uint64_t (-1);
E = uint64_t (-1);
implysp = uint64_t (-1);
firstalign = uint64_t (-1);
#endif
}
};
// V1 format: a and l are included in the edge itself
struct edgeinfowithscores : edgeinfo
{
float a;
float l;
edgeinfowithscores (size_t pS, size_t pE, float a, float l, size_t pfirstalign) : edgeinfo (pS, pE, pfirstalign), a(a), l(l) {}
edgeinfowithscores() // [v-hansu] initialize to impossible values
{
#ifdef INITIAL_STRANGE
a = LOGZERO;
l = LOGZERO;
#endif
}
};
struct aligninfo // phonetic alignment
{
unsigned int unit : 19; // triphone index
unsigned int frames : 11; // duration in frames
// note: V1 did not have the following, which were instead the two 2 bits of 'frames'
unsigned int unused : 1; // (for future use)
unsigned int last : 1; // set for last entry
aligninfo (size_t punit, size_t pframes) : unit ((unsigned int) punit), frames ((unsigned int) pframes), unused (0), last (0)
{
checkoverflow (unit, punit, "aligninfo::unit");
checkoverflow (frames, pframes, "aligninfo::frames");
}
aligninfo() // [v-hansu] initialize to impossible values
{
#ifdef INITIAL_STRANGE
unit = unsigned int (-1);
frames = unsigned int (-1);
unused = unsigned int (-1);
last = unsigned int (-1);
#endif
}
template<class IDMAP> void updateunit (const IDMAP & idmap/*[unit] -> new unit*/) // update 'unit' w.r.t. a different mapping, with bit-field overflow check
{
const size_t mappedunit = idmap[unit];
unit = (unsigned int) mappedunit;
checkoverflow (unit, mappedunit, "aligninfo::unit");
}
};
};};

Просмотреть файл

@ -1,299 +0,0 @@
//
// <copyright file="minibatchiterator.h" company="Microsoft">
// Copyright (c) Microsoft Corporation. All rights reserved.
// </copyright>
//
// minibatchiterator.h -- iterator for minibatches
#pragma once
#define NONUMLATTICEMMI // [v-hansu] move from main.cpp, no numerator lattice for mmi training
#include <vector>
#include <unordered_map>
#include "ssematrix.h"
#include "latticearchive.h" // for reading HTK phoneme lattices (MMI training)
#include "simple_checked_arrays.h" // for const_array_ref
namespace msra { namespace dbn {
// ---------------------------------------------------------------------------
// latticesource -- manages loading of lattices for MMI (in pairs for numer and denom)
// ---------------------------------------------------------------------------
class latticesource
{
const msra::lattices::archive numlattices, denlattices;
public:
latticesource (std::pair<std::vector<wstring>,std::vector<wstring>> latticetocs, const std::unordered_map<std::string,size_t> & modelsymmap)
: numlattices (latticetocs.first, modelsymmap), denlattices (latticetocs.second, modelsymmap) {}
bool empty() const
{
#ifndef NONUMLATTICEMMI // TODO:set NUM lattice to null so as to save memory
if (numlattices.empty() ^ denlattices.empty())
throw std::runtime_error("latticesource: numerator and denominator lattices must be either both empty or both not empty");
#endif
return denlattices.empty();
}
bool haslattice (wstring key) const
{
#ifdef NONUMLATTICEMMI
return denlattices.haslattice (key);
#else
return numlattices.haslattice (key) && denlattices.haslattice (key);
#endif
}
class latticepair : public pair<msra::lattices::lattice,msra::lattices::lattice>
{
public:
// NOTE: we don't check numerator lattice now
size_t getnumframes () const { return second.getnumframes(); }
size_t getnumnodes () const { return second.getnumnodes(); }
size_t getnumedges () const { return second.getnumedges(); }
wstring getkey () const { return second.getkey(); }
};
void getlattices (const std::wstring & key, shared_ptr<const latticesource::latticepair> & L, size_t expectedframes) const
{
shared_ptr<latticepair> LP (new latticepair);
denlattices.getlattice (key, LP->second, expectedframes); // this loads the lattice from disk, using the existing L.second object
L = LP;
}
};
// ---------------------------------------------------------------------------
// minibatchsource -- abstracted interface into frame sources
// There are three implementations:
// - the old minibatchframesource to randomize across frames and page to disk
// - minibatchutterancesource that randomizes in chunks and pages from input files directly
// - a wrapper that uses a thread to read ahead in parallel to CPU/GPU processing
// ---------------------------------------------------------------------------
class minibatchsource
{
public:
// read a minibatch
// This function returns all values in a "caller can keep them" fashion:
// - uids are stored in a huge 'const' array, and will never go away
// - transcripts are copied by value
// - lattices are returned as a shared_ptr
// Thus, getbatch() can be called in a thread-safe fashion, allowing for a 'minibatchsource' implementation that wraps another with a read-ahead thread.
// Return value is 'true' if it did read anything from disk, and 'false' if data came only from RAM cache. This is used for controlling the read-ahead thread.
virtual bool getbatch (const size_t globalts,
const size_t framesrequested, msra::dbn::matrix & feat, std::vector<size_t> & uids,
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> & transcripts,
std::vector<shared_ptr<const latticesource::latticepair>> & lattices) = 0;
// alternate (updated) definition for multiple inputs/outputs - read as a vector of feature matrixes or a vector of label strings
virtual bool getbatch (const size_t globalts,
const size_t framesrequested, std::vector<msra::dbn::matrix> & feat, std::vector<std::vector<size_t>> & uids,
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> & transcripts,
std::vector<shared_ptr<const latticesource::latticepair>> & lattices) = 0;
// getbatch() overload to support subsetting of mini-batches for parallel training
// Default implementation does not support subsetting and throws an exception on
// calling this overload with a numsubsets value other than 1.
virtual bool getbatch(const size_t globalts,
const size_t framesrequested, const size_t subsetnum, const size_t numsubsets, size_t & framesadvanced,
std::vector<msra::dbn::matrix> & feat, std::vector<std::vector<size_t>> & uids,
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> & transcripts,
std::vector<shared_ptr<const latticesource::latticepair>> & lattices)
{
assert((subsetnum == 0) && (numsubsets == 1) && !supportsbatchsubsetting()); subsetnum; numsubsets;
bool retVal = getbatch(globalts, framesrequested, feat, uids, transcripts, lattices);
framesadvanced = feat[0].cols();
return retVal;
}
virtual bool supportsbatchsubsetting() const
{
return false;
}
virtual size_t totalframes() const = 0;
virtual double gettimegetbatch () = 0; // used to report runtime
virtual size_t firstvalidglobalts (const size_t globalts) = 0; // get first valid epoch start from intended 'globalts'
virtual const std::vector<size_t> & unitcounts() const = 0; // report number of senones
virtual void setverbosity(int newverbosity) = 0;
virtual ~minibatchsource() { }
};
// ---------------------------------------------------------------------------
// minibatchiterator -- class to iterate over one epoch, minibatch by minibatch
// This iterator supports both random frames and random utterances through the minibatchsource interface whichis common to both.
// This supports multiple data passes with identical randomization; which is intended to be used for utterance-based training.
// ---------------------------------------------------------------------------
class minibatchiterator
{
void operator= (const minibatchiterator &); // (non-copyable)
const size_t epochstartframe;
const size_t epochendframe;
size_t firstvalidepochstartframe; // epoch start frame rounded up to first utterance boundary after epoch boundary
const size_t requestedmbframes; // requested mb size; actual minibatches can be smaller (or even larger for lattices)
const size_t datapasses; // we return the data this many times; caller must sub-sample with 'datapass'
msra::dbn::minibatchsource & source; // feature source to read from
// subset to read during distributed data-parallel training (no subsetting: (0,1))
size_t subsetnum;
size_t numsubsets;
std::vector<msra::dbn::matrix> featbuf; // buffer for holding curernt minibatch's frames
std::vector<std::vector<size_t>> uids; // buffer for storing current minibatch's frame-level label sequence
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> transcripts; // buffer for storing current minibatch's word-level label sequences (if available and used; empty otherwise)
std::vector<shared_ptr<const latticesource::latticepair>> lattices; // lattices of the utterances in current minibatch (empty in frame mode)
size_t mbstartframe; // current start frame into generalized time line (used for frame-wise mode and for diagnostic messages)
size_t actualmbframes; // actual number of frames in current minibatch
size_t mbframesadvanced; // logical number of frames the current MB represents (to advance time; > featbuf.cols() possible, intended for the case of distributed data-parallel training)
size_t datapass; // current datapass = pass through the data
double timegetbatch; // [v-hansu] for time measurement
double timechecklattice;
private:
// fetch the next mb
// This updates featbuf, uids[], mbstartframe, and actualmbframes.
void fillorclear()
{
if (!hasdata()) // we hit the end of the epoch: just cleanly clear out everything (not really needed, can't be requested ever)
{
foreach_index(i, featbuf)
featbuf[i].resize (0, 0);
foreach_index(i,uids)
uids[i].clear();
transcripts.clear();
actualmbframes = 0;
return;
}
// process one mini-batch (accumulation and update)
assert (requestedmbframes > 0);
const size_t requestedframes = min (requestedmbframes, epochendframe - mbstartframe); // (< mbsize at end)
assert (requestedframes > 0);
source.getbatch (mbstartframe, requestedframes, subsetnum, numsubsets, mbframesadvanced, featbuf, uids, transcripts, lattices);
timegetbatch = source.gettimegetbatch();
actualmbframes = featbuf[0].cols(); // for single i/o, there featbuf is length 1
// note:
// - in frame mode, actualmbframes may still return less if at end of sweep
// - in utterance mode, it likely returns less than requested, and
// it may also be > epochendframe (!) for the last utterance, which, most likely, crosses the epoch boundary
// - in case of data parallelism, featbuf.cols() < mbframesadvanced
auto_timer timerchecklattice;
if (!lattices.empty())
{
size_t totalframes = 0;
foreach_index (i, lattices)
totalframes += lattices[i]->getnumframes();
if (totalframes != actualmbframes)
throw std::logic_error ("fillorclear: frames in lattices do not match minibatch size");
}
timechecklattice = timerchecklattice;
}
bool hasdata() const { return mbstartframe < epochendframe; } // true if we can access and/or advance
void checkhasdata() const { if (!hasdata()) throw std::logic_error ("minibatchiterator: access beyond end of epoch"); }
public:
// interface: for (minibatchiterator i (...), i, i++) { ... }
minibatchiterator (msra::dbn::minibatchsource & source, size_t epoch, size_t epochframes, size_t requestedmbframes, size_t subsetnum, size_t numsubsets, size_t datapasses)
: source (source),
epochstartframe (epoch * epochframes),
epochendframe (epochstartframe + epochframes),
requestedmbframes (requestedmbframes),
subsetnum(subsetnum), numsubsets(numsubsets),
datapasses (datapasses),
timegetbatch (0), timechecklattice (0)
{
firstvalidepochstartframe = source.firstvalidglobalts (epochstartframe); // epochstartframe may fall between utterance boundaries; this gets us the first valid boundary
fprintf (stderr, "minibatchiterator: epoch %d: frames [%d..%d] (first utterance at frame %d), data subset %d of %d, with %d datapasses\n",
epoch, epochstartframe, epochendframe, firstvalidepochstartframe, subsetnum, numsubsets, datapasses);
mbstartframe = firstvalidepochstartframe;
datapass = 0;
fillorclear(); // get the first batch
}
// TODO not nice, but don't know how to access these frames otherwise
// mbiterator constructor, set epochstart and -endframe explicitly
minibatchiterator(msra::dbn::minibatchsource & source, size_t epoch, size_t epochstart, size_t epochend, size_t requestedmbframes, size_t subsetnum, size_t numsubsets, size_t datapasses)
: source (source),
epochstartframe (epochstart),
epochendframe (epochend),
requestedmbframes (requestedmbframes),
subsetnum(subsetnum), numsubsets(numsubsets),
datapasses (datapasses),
timegetbatch (0), timechecklattice (0)
{
firstvalidepochstartframe = source.firstvalidglobalts (epochstartframe); // epochstartframe may fall between utterance boundaries; this gets us the first valid boundary
fprintf (stderr, "minibatchiterator: epoch %d: frames [%d..%d] (first utterance at frame %d), data subset %d of %d, with %d datapasses\n",
epoch, epochstartframe, epochendframe, firstvalidepochstartframe, subsetnum, numsubsets, datapasses);
mbstartframe = firstvalidepochstartframe;
datapass = 0;
fillorclear(); // get the first batch
}
// need virtual destructor to ensure proper destruction
virtual ~minibatchiterator()
{}
// returns true if we still have data
operator bool() const { return hasdata(); }
// advance to the next minimb
void operator++(int/*denotes postfix version*/)
{
checkhasdata();
mbstartframe += mbframesadvanced;
// if we hit the end, we will get mbstartframe >= epochendframe <=> !hasdata()
// (most likely actually mbstartframe > epochendframe since the last utterance likely crosses the epoch boundary)
// in case of multiple datapasses, reset to start when hitting the end
if (!hasdata() && datapass + 1 < datapasses)
{
mbstartframe = firstvalidepochstartframe;
datapass++;
fprintf (stderr, "\nminibatchiterator: entering %zu-th repeat pass through the data\n", datapass+1);
}
fillorclear();
}
// accessors to current minibatch
size_t currentmbstartframe() const { return mbstartframe; }
size_t currentmbframes() const { return actualmbframes; }
size_t currentmbframesadvanced() const { return mbframesadvanced; }
size_t currentmblattices() const { return lattices.size(); }
size_t currentdatapass() const { return datapass; } // 0..datapasses-1; use this for sub-sampling
size_t requestedframes() const {return requestedmbframes; }
double gettimegetbatch () {return timegetbatch;}
double gettimechecklattice () {return timechecklattice;}
bool isfirst() const { return mbstartframe == firstvalidepochstartframe && datapass == 0; }
float progress() const // (note: 100%+eps possible for last utterance)
{
const float epochframes = (float) (epochendframe - epochstartframe);
return (mbstartframe + mbframesadvanced - epochstartframe + datapass * epochframes) / (datapasses * epochframes);
}
std::pair<size_t,size_t> range() const { return make_pair (epochstartframe, epochendframe); }
// return the current minibatch frames as a matrix ref into the feature buffer
// Number of frames is frames().cols() == currentmbframes().
// For frame-based randomization, this is 'requestedmbframes' most of the times, while for utterance randomization,
// this depends highly on the utterance lengths.
// User is allowed to manipulate the frames... for now--TODO: move silence filtering here as well
msra::dbn::matrixstripe frames(size_t i) { checkhasdata(); assert(featbuf.size()>=i+1); return msra::dbn::matrixstripe (featbuf[i], 0, actualmbframes); }
msra::dbn::matrixstripe frames() { checkhasdata(); assert(featbuf.size()==1); return msra::dbn::matrixstripe (featbuf[0], 0, actualmbframes); }
// return the reference transcript labels (state alignment) for current minibatch
/*const*/ std::vector<size_t> & labels() { checkhasdata(); assert(uids.size()==1);return uids[0]; }
/*const*/ std::vector<size_t> & labels(size_t i) { checkhasdata(); assert(uids.size()>=i+1); return uids[i]; }
// return a lattice for an utterance (caller should first get total through currentmblattices())
shared_ptr<const msra::dbn::latticesource::latticepair> lattice (size_t uttindex) const { return lattices[uttindex]; } // lattices making up the current
// return the reference transcript labels (words with alignments) for current minibatch (or empty if no transcripts requested)
const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word> transcript (size_t uttindex) { return transcripts.empty() ? const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>() : transcripts[uttindex]; }
};
};};

Просмотреть файл

@ -1,279 +0,0 @@
//
// <copyright file="minibatchsourcehelpers.h" company="Microsoft">
// Copyright (c) Microsoft Corporation. All rights reserved.
// </copyright>
//
// minibatchsourcehelpers.h -- helper classes for minibatch sources
//
#pragma once
#include "basetypes.h"
#include <stdio.h>
#include <vector>
#include <algorithm>
#ifndef __unix__
#include "ssematrix.h" // for matrix type
#endif
namespace msra { namespace dbn {
// ---------------------------------------------------------------------------
// augmentneighbors() -- augmenting features with their neighbor frames
// ---------------------------------------------------------------------------
// implant a sub-vector into a vector, for use in augmentneighbors
template<class INV, class OUTV> static void copytosubvector (const INV & inv, size_t subvecindex, OUTV & outv)
{
size_t subdim = inv.size();
assert (outv.size() % subdim == 0);
size_t k0 = subvecindex * subdim;
foreach_index (k, inv)
outv[k + k0] = inv[k];
}
// compute the augmentation extent (how many frames added on each side)
static size_t augmentationextent (size_t featdim/*augment from*/, size_t modeldim/*to*/)
{
const size_t windowframes = modeldim / featdim; // total number of frames to generate
const size_t extent = windowframes / 2; // extend each side by this
if (modeldim % featdim != 0)
throw runtime_error ("augmentationextent: model vector size not multiple of input features");
if (windowframes % 2 == 0)
throw runtime_error (msra::strfun::strprintf ("augmentationextent: neighbor expansion of input features to %d not symmetrical", windowframes));
return extent;
}
// augment neighbor frames for a frame (correctly not expanding across utterance boundaries)
// The boundaryflags[] array, if not empty, flags first (-1) and last (+1) frame, i.e. frames to not expand across.
// The output 'v' must have te-ts columns.
template<class MATRIX, class VECTOR> static void augmentneighbors (const MATRIX & frames, const std::vector<char> & boundaryflags, size_t t,
VECTOR & v)
{
// how many frames are we adding on each side
const size_t extent = augmentationextent (frames[t].size(), v.size());
// copy the frame and its neighbors
// Once we hit a boundaryflag in either direction, do not move index beyond.
copytosubvector (frames[t], extent, v); // frame[t] sits right in the middle
size_t t1 = t; // index for frames on to the left
size_t t2 = t; // and right
for (size_t n = 1; n <= extent; n++)
{
#ifdef SAMPLING_EXPERIMENT
if (boundaryflags.empty()) // boundary flags not given: 'frames' is full utterance
{
if (t1 >= SAMPLING_EXPERIMENT) t1 -= SAMPLING_EXPERIMENT; // index does not move beyond boundary
if (t2 + SAMPLING_EXPERIMENT < frames.size()) t2 += SAMPLING_EXPERIMENT;
}
else
{
if (boundaryflags[t1] != -1) t1 -= SAMPLING_EXPERIMENT; // index does not move beyond a set boundaryflag,
if (boundaryflags[t2] != 1) t2 += SAMPLING_EXPERIMENT; // because that's the start/end of the utterance
}
#else
if (boundaryflags.empty()) // boundary flags not given: 'frames' is full utterance
{
if (t1 > 0) t1--; // index does not move beyond boundary
if (t2 + 1 < frames.size()) t2++;
}
else
{
if (boundaryflags[t1] != -1) t1--; // index does not move beyond a set boundaryflag,
if (boundaryflags[t2] != 1) t2++; // because that's the start/end of the utterance
}
#endif
copytosubvector (frames[t1], extent - n, v);
copytosubvector (frames[t2], extent + n, v);
}
}
// augment neighbor frames for a frame (correctly not expanding across utterance boundaries)
// The boundaryflags[] array, if not empty, flags first (-1) and last (+1) frame, i.e. frames to not expand across.
// The output 'v' must have te-ts columns.
template<class MATRIX, class VECTOR> static void augmentneighbors(const MATRIX & frames, const std::vector<char> & boundaryflags, size_t t, const size_t leftextent, const size_t rightextent,
VECTOR & v)
{
// copy the frame and its neighbors
// Once we hit a boundaryflag in either direction, do not move index beyond.
copytosubvector(frames[t], leftextent, v); // frame[t] sits right in the middle
size_t t1 = t; // index for frames on to the left
size_t t2 = t; // and right
for (size_t n = 1; n <= leftextent; n++)
{
if (boundaryflags.empty()) // boundary flags not given: 'frames' is full utterance
{
if (t1 > 0) t1--; // index does not move beyond boundary
}
else
{
if (boundaryflags[t1] != -1) t1--; // index does not move beyond a set boundaryflag,
}
copytosubvector(frames[t1], leftextent - n, v);
}
for (size_t n = 1; n <= rightextent; n++)
{
if (boundaryflags.empty()) // boundary flags not given: 'frames' is full utterance
{
if (t2 + 1 < frames.size()) t2++;
}
else
{
if (boundaryflags[t2] != 1) t2++; // because that's the start/end of the utterance
}
copytosubvector(frames[t2], rightextent + n, v);
}
}
// augment neighbor frames for one frame t in frames[] according to boundaryflags[]; result returned in column j of v
template<class INMATRIX, class OUTMATRIX> static void augmentneighbors (const INMATRIX & frames, const std::vector<char> & boundaryflags, size_t t,
OUTMATRIX & v, size_t j)
{
auto v_j = v.col(j); // the vector to fill in
augmentneighbors (frames, boundaryflags, t, v_j);
}
// augment neighbor frames for one frame t in frames[] according to boundaryflags[]; result returned in column j of v
template<class INMATRIX, class OUTMATRIX> static void augmentneighbors(const INMATRIX & frames, const std::vector<char> & boundaryflags, size_t t, size_t leftextent, size_t rightextent,
OUTMATRIX & v, size_t j)
{
auto v_j = v.col(j); // the vector to fill in
augmentneighbors(frames, boundaryflags, t, leftextent, rightextent, v_j);
}
// augment neighbor frames for a sequence of frames (part of an utterance, possibly spanning across boundaries)
template<class MATRIX> static void augmentneighbors (const std::vector<std::vector<float>> & frames, const std::vector<char> & boundaryflags,
size_t ts, size_t te, // range [ts,te)
MATRIX & v)
{
for (size_t t = ts; t < te; t++)
{
auto v_t = v.col(t-ts); // the vector to fill in
augmentneighbors (frames, boundaryflags, t, v_t);
}
}
// augment neighbor frames for a sequence of frames (part of an utterance, possibly spanning across boundaries)
template<class MATRIX> static void augmentneighbors(const std::vector<std::vector<float>> & frames, const std::vector<char> & boundaryflags, size_t leftextent, size_t rightextent,
size_t ts, size_t te, // range [ts,te)
MATRIX & v)
{
for (size_t t = ts; t < te; t++)
{
auto v_t = v.col(t - ts); // the vector to fill in
augmentneighbors(frames, boundaryflags, t, leftextent, rightextent, v_t);
}
}
// ---------------------------------------------------------------------------
// randomordering -- class to help manage randomization of input data
// ---------------------------------------------------------------------------
static inline size_t rand (const size_t begin, const size_t end)
{
const size_t randno = ::rand() * RAND_MAX + ::rand(); // BUGBUG: still only covers 32-bit range
return begin + randno % (end - begin);
}
class randomordering // note: NOT thread-safe at all
{
// constants for randomization
const static size_t randomizeAuto=0;
const static size_t randomizeDisable=(size_t)-1;
typedef unsigned int INDEXTYPE; // don't use size_t, as this saves HUGE amounts of RAM
std::vector<INDEXTYPE> map; // [t] -> t' indices in randomized order
size_t currentseed; // seed for current sequence
size_t randomizationrange; // t - randomizationrange/2 <= t' < t + randomizationrange/2 (we support this to enable swapping)
// special values (randomizeAuto, randomizeDisable)
void invalidate() { currentseed = (size_t) -1; }
public:
randomordering() { invalidate(); }
void resize (size_t len, size_t p_randomizationrange) { randomizationrange = p_randomizationrange>0?p_randomizationrange:len; map.resize (len); invalidate(); }
// return the randomized feature bounds for a time range
std::pair<size_t,size_t> bounds (size_t ts, size_t te) const
{
size_t tbegin = max (ts, randomizationrange/2) - randomizationrange/2;
size_t tend = min (te + randomizationrange/2, map.size());
return std::make_pair<size_t,size_t> (move(tbegin), move(tend));
}
// this returns the map directly (read-only) and will lazily initialize it for a given seed
const std::vector<INDEXTYPE> & operator() (size_t seed) //throw()
{
// if wrong seed then lazily recache the sequence
if (seed != currentseed)
{
// test for numeric overflow
if (map.size()-1 != (INDEXTYPE) (map.size()-1))
throw std::runtime_error ("randomordering: INDEXTYPE has too few bits for this corpus");
// 0, 1, 2...
foreach_index (t, map) map[t] = (INDEXTYPE) t;
// now randomize them
if (randomizationrange != randomizeDisable)
{
#if 1 // change to 0 to disable randomizing
if (map.size() > RAND_MAX * (size_t) RAND_MAX)
throw std::runtime_error ("randomordering: too large training set: need to change to different random generator!");
srand ((unsigned int) seed);
size_t retries = 0;
foreach_index (t, map)
{
for (int tries = 0; tries < 5; tries++)
{
// swap current pos with a random position
// Random positions are limited to t+randomizationrange.
// This ensures some locality suitable for paging with a sliding window.
const size_t tbegin = max ((size_t) t, randomizationrange/2) - randomizationrange/2; // range of window --TODO: use bounds() function above
const size_t tend = min (t + randomizationrange/2, map.size());
assert (tend >= tbegin); // (guard against potential numeric-wraparound bug)
const size_t trand = rand (tbegin, tend); // random number within windows
assert ((size_t) t <= trand + randomizationrange/2 && trand < (size_t) t + randomizationrange/2);
// if range condition is fulfilled then swap
if (trand <= map[t] + randomizationrange/2 && map[t] < trand + randomizationrange/2
&& (size_t) t <= map[trand] + randomizationrange/2 && map[trand] < (size_t) t + randomizationrange/2)
{
::swap (map[t], map[trand]);
break;
}
// but don't multi-swap stuff out of its range (for swapping positions that have been swapped before)
// instead, try again with a different random number
retries++;
}
}
fprintf (stderr, "randomordering: %zu retries for %zu elements (%.1f%%) to ensure window condition\n", retries, map.size(), 100.0 * retries / map.size());
// ensure the window condition
foreach_index (t, map) assert ((size_t) t <= map[t] + randomizationrange/2 && map[t] < (size_t) t + randomizationrange/2);
#if 1 // and a live check since I don't trust myself here yet
foreach_index (t, map) if (!((size_t) t <= map[t] + randomizationrange/2 && map[t] < (size_t) t + randomizationrange/2))
{
fprintf (stderr, "randomordering: windowing condition violated %d -> %d\n", t, map[t]);
throw std::logic_error ("randomordering: windowing condition violated");
}
#endif
#endif
#if 1 // test whether it is indeed a unique complete sequence
auto map2 = map;
::sort (map2.begin(), map2.end());
foreach_index (t, map2) assert (map2[t] == (size_t) t);
#endif
fprintf (stderr, "randomordering: recached sequence for seed %d: %d, %d, ...\n", (int) seed, (int) map[0], (int) map[1]);
}
currentseed = seed;
}
return map; // caller can now access it through operator[]
}
};
//typedef unsigned short CLASSIDTYPE; // type to store state ids; don't use size_t --saves HUGE amounts of RAM
typedef unsigned int CLASSIDTYPE; //mseltzer - change to unsigned int for untied context-dependent phones
};};

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -1,254 +0,0 @@
//
// <copyright file="numahelpers.h" company="Microsoft">
// Copyright (c) Microsoft Corporation. All rights reserved.
// </copyright>
//
// numahelpers.h -- some helpers with NUMA
#pragma once
#ifndef __unix__
#include <Windows.h>
#include "pplhelpers.h"
#endif
#include <stdexcept>
#include "simple_checked_arrays.h"
#include "basetypes.h" // for FormatWin32Error
namespace msra { namespace numa {
// ... TODO: this can be a 'static', as it should only be set during foreach_node but not outside
extern int node_override; // -1 = normal operation; >= 0: force a specific NUMA node
// force a specific NUMA node (only do this during single-threading!)
static inline void overridenode (int n = -1)
{
node_override = n;
}
// get the number of NUMA nodes we would like to distinguish
static inline size_t getnumnodes()
{
ULONG n;
if (!GetNumaHighestNodeNumber (&n)) return 1;
return n +1;
}
// execute body (node, i, n), i in [0,n) on all NUMA nodes in small chunks
template <typename FUNCTION> void parallel_for_on_each_numa_node (bool multistep, const FUNCTION & body)
{
// get our configuration
const size_t cores = ppl_cores;
assert (cores > 0);
const size_t nodes = getnumnodes();
const size_t corespernode = (cores -1) / nodes + 1;
// break into 8 steps per thread
const size_t stepspernode = multistep ? 16 : 1;
const size_t steps = corespernode * stepspernode;
// now run on many threads, hoping to hit all NUMA nodes, until we are done
hardcoded_array<LONG/*unsigned int*/,256> nextstepcounters; // next block to run for a NUMA node
if (nodes > nextstepcounters.size())
throw std::logic_error ("parallel_for_on_each_numa_node: nextstepcounters buffer too small, need to increase hard-coded size");
for (size_t k = 0; k < nodes; k++) nextstepcounters[k] = 0;
overridenode();
//unsigned int totalloops = 0; // for debugging only, can be removed later
msra::parallel::parallel_for (0, nodes * steps /*execute each step on each NUMA node*/, 1, [&](size_t /*dummy*/)
{
const size_t numanodeid = getcurrentnode();
// find a node that still has work left, preferring our own node
// Towards the end we will run on wrong nodes, but what can we do.
for (size_t node1 = numanodeid; node1 < numanodeid + nodes; node1++)
{
const size_t node = node1 % nodes;
const unsigned int step = InterlockedIncrement (&nextstepcounters[node]) -1; // grab this step
if (step >= steps) // if done then counter has exceeded the required number of steps
continue; // so try next NUMA node
// found one: execute and terminate loop
body (node, step, steps);
//InterlockedIncrement (&totalloops);
return; // done
}
// oops??
throw std::logic_error ("parallel_for_on_each_numa_node: no left-over block found--should not get here!!");
});
//assert (totalloops == nodes * steps);
}
// execute a passed function once for each NUMA node
// This must be run from the main thread only.
// ... TODO: honor ppl_cores == 1 for comparative measurements against single threads.
template<typename FUNCTION>
static void foreach_node_single_threaded (const FUNCTION & f)
{
const size_t n = getnumnodes();
for (size_t i = 0; i < n; i++)
{
overridenode ((int) i);
f();
}
overridenode (-1);
}
// get the current NUMA node
static inline size_t getcurrentnode()
{
// we can force it to be a certain node, for use in initializations
if (node_override >= 0)
return (size_t) node_override;
// actually use current node
DWORD i = GetCurrentProcessorNumber(); // note: need to change for >63 processors
UCHAR n;
if (!GetNumaProcessorNode ((UCHAR) i, &n)) return 0;
if (n == 0xff)
throw std::logic_error ("GetNumaProcessorNode() failed to determine NUMA node for GetCurrentProcessorNumber()??");
return n;
}
// allocate memory
// Allocation seems to be at least on a 512-byte boundary. We nevertheless verify alignment requirements.
typedef LPVOID (WINAPI *VirtualAllocExNuma_t) (HANDLE,LPVOID,SIZE_T,DWORD,DWORD,DWORD);
static VirtualAllocExNuma_t VirtualAllocExNuma = (VirtualAllocExNuma_t)-1;
static inline void * malloc (size_t n, size_t align)
{
// VirtualAllocExNuma() only exists on Vista+, so go through an explicit function pointer
if (VirtualAllocExNuma == (VirtualAllocExNuma_t)-1)
{
VirtualAllocExNuma = (VirtualAllocExNuma_t) GetProcAddress (GetModuleHandle ( TEXT ("kernel32.dll")), "VirtualAllocExNuma");
}
// if we have the function then do a NUMA-aware allocation
void * p;
if (VirtualAllocExNuma != NULL)
{
size_t node = getcurrentnode();
// "all Win32 heap allocations that are 1 MB or greater are forwarded directly to NtAllocateVirtualMemory
// when they are allocated and passed directly to NtFreeVirtualMemory when they are freed" Greg Colombo, 2010/11/17
if (n < 1024*1024)
n = 1024*1024; // -> brings NUMA-optimized code back to Node Interleave level (slightly faster)
p = VirtualAllocExNuma (GetCurrentProcess(), NULL, n, MEM_COMMIT | MEM_RESERVE, PAGE_READWRITE, (DWORD) node);
}
else // on old OS call no-NUMA version
{
p = VirtualAllocEx (GetCurrentProcess(), NULL, n, MEM_COMMIT | MEM_RESERVE, PAGE_READWRITE);
}
if (p == NULL)
fprintf (stderr, "numa::malloc: failed allocating %d bytes with alignment %d\n", n, align);
if (((size_t) p) % align != 0)
throw std::logic_error ("VirtualAllocExNuma() returned an address that does not match the alignment requirement");
return p;
}
// free memory allocated with numa::malloc()
static inline void free (void * p)
{
assert (p != NULL);
if (!VirtualFree (p, 0, MEM_RELEASE))
throw std::logic_error ("VirtualFreeEx failure");
}
// dump memory allocation
static inline void showavailablememory (const char * what)
{
size_t n = getnumnodes();
for (size_t i = 0; i < n; i++)
{
ULONGLONG availbytes = 0;
BOOL rc = GetNumaAvailableMemoryNode ((UCHAR) i, &availbytes);
const double availmb = availbytes / (1024.0*1024.0);
if (rc)
fprintf (stderr, "%s: %8.2f MB available on NUMA node %d\n", what, availmb, i);
else
fprintf (stderr, "%s: error '%S' for getting available memory on NUMA node %d\n", what, FormatWin32Error (::GetLastError()).c_str(), i);
}
}
// determine NUMA node with most memory available
static inline size_t getmostspaciousnumanode()
{
size_t n = getnumnodes();
size_t bestnode = 0;
ULONGLONG bestavailbytes = 0;
for (size_t i = 0; i < n; i++)
{
ULONGLONG availbytes = 0;
GetNumaAvailableMemoryNode ((UCHAR) i, &availbytes);
if (availbytes > bestavailbytes)
{
bestavailbytes = availbytes;
bestnode = i;
}
}
return bestnode;
}
#if 0 // this is no longer used (we now parallelize the big matrix products directly)
// class to manage multiple copies of data on local NUMA nodes
template<class DATATYPE,class CACHEDTYPE> class numalocaldatacache
{
numalocaldatacache (const numalocaldatacache&); numalocaldatacache & operator= (const numalocaldatacache&);
// the data set we associate to
const DATATYPE & data;
// cached copies of the models for NUMA
vector<unique_ptr<CACHEDTYPE>> cache;
// get the pointer to the clone for the NUMA node of the current thread (must exist)
CACHEDTYPE * getcloneptr()
{
return cache[getcurrentnode()].get();
}
public:
numalocaldatacache (const DATATYPE & data) : data (data), cache (getnumnodes())
{
foreach_node_single_threaded ([&]()
{
cache[getcurrentnode()].reset (new CACHEDTYPE (data));
});
}
// this takes the cached versions of the parent model
template<typename ARGTYPE1,typename ARGTYPE2,typename ARGTYPE3>
numalocaldatacache (numalocaldatacache<DATATYPE,DATATYPE> & parentcache, const ARGTYPE1 & arg1, const ARGTYPE2 & arg2, const ARGTYPE3 & arg3) : data (*(DATATYPE*)nullptr), cache (getnumnodes())
{
foreach_node_single_threaded ([&]()
{
const DATATYPE & parent = parentcache.getclone();
size_t numanodeid = getcurrentnode();
cache[numanodeid].reset (new CACHEDTYPE (parent, arg1, arg2, arg3));
});
}
// re-clone --update clones from the cached 'data' reference
// This is only valid if CACHEDTYPE==DATATYPE.
// ... parallelize this!
void reclone()
{
parallel_for_on_each_numa_node (true, [&] (size_t numanodeid, size_t step, size_t steps)
{
if (step != 0)
return; // ... TODO: tell parallel_for_on_each_numa_node() to only have one step, or parallelize
cache[numanodeid].get()->copyfrom (data); // copy it all over
});
}
// post-process all clones
// 'numanodeid' is ideally the current NUMA node most of the time, but not required.
template<typename POSTPROCFUNC>
void process (const POSTPROCFUNC & postprocess)
{
parallel_for_on_each_numa_node (true, [&] (size_t numanodeid, size_t step, size_t steps)
{
postprocess (*cache[numanodeid].get(), step, steps);
});
}
// a thread calls this to get the data pre-cloned for its optimal NUMA node
// (only works for memory allocated through msra::numa::malloc())
const CACHEDTYPE & getclone() const { return *getcloneptr(); }
CACHEDTYPE & getclone() { return *getcloneptr(); }
};
#endif
};};

Просмотреть файл

@ -1,99 +0,0 @@
//
// <copyright file="pplhelpers.h" company="Microsoft">
// Copyright (c) Microsoft Corporation. All rights reserved.
// </copyright>
//
// pplhelpers.h -- some helpers for PPL library
//
#pragma once
#ifndef __unix__
#include <ppl.h>
#endif
namespace msra { namespace parallel {
// ===========================================================================
// helpers related to multiprocessing and NUMA
// ===========================================================================
// determine number of CPU cores on this machine
static inline size_t determine_num_cores()
{
SYSTEM_INFO sysInfo;
GetSystemInfo (&sysInfo);
return sysInfo.dwNumberOfProcessors;
}
extern size_t ppl_cores; // number of cores to run on as requested by user
static inline void set_cores (size_t cores)
{
ppl_cores = cores;
}
static inline size_t get_cores() // if returns 1 then no parallelization will be done
{
return ppl_cores;
}
#if 0
// execute body() a bunch of times for hopefully each core
// This is not precise. Cores will be hit multiple times, and some cores may not be touched.
template <typename FUNCTION> void for_all_numa_nodes_approximately (const FUNCTION & body)
{
if (ppl_cores > 1) // parallel computation (regular)
parallel_for ((size_t) 0, ppl_cores * 2, (size_t) 1, [&](size_t) { body(); });
else // for comparison: single-threaded (this also documents what the above means)
body();
}
#endif
// wrapper around Concurrency::parallel_for() to allow disabling parallelization altogether
template <typename FUNCTION> void parallel_for (size_t begin, size_t end, size_t step, const FUNCTION & f)
{
const size_t cores = ppl_cores;
if (cores > 1) // parallel computation (regular)
{
//fprintf (stderr, "foreach_index_block: computing %d blocks of %d frames on %d cores\n", nblocks, nfwd, determine_num_cores());
Concurrency::parallel_for (begin, end, step, f);
}
else // for comparison: single-threaded (this also documents what the above means)
{
//fprintf (stderr, "foreach_index_block: computing %d blocks of %d frames on a single thread\n", nblocks, nfwd);
for (size_t j0 = begin; j0 < end; j0 += step) f (j0);
}
}
// execute a function 'body (j0, j1)' for j = [0..n) in chunks of ~targetstep in 'cores' cores
// Very similar to parallel_for() except that body function also takes end index,
// and the 'targetsteps' gets rounded a little to better map to 'cores.'
// ... TODO: Currently, 'cores' does not limit the number of threads in parallel_for() (not so critical, fix later or never)
template <typename FUNCTION> void foreach_index_block (size_t n, size_t targetstep, size_t targetalignment, const FUNCTION & body)
{
const size_t cores = ppl_cores;
const size_t maxnfwd = 2 * targetstep;
size_t nblocks = (n + targetstep / 2) / targetstep;
if (nblocks == 0) nblocks = 1;
// round to a multiple of the number of cores
if (nblocks < cores) // less than # cores -> round up
nblocks = (1+(nblocks-1)/cores) * cores;
else // more: round down (reduce overhead)
nblocks = nblocks / cores * cores;
size_t nfwd = 1 + (n - 1) / nblocks;
assert (nfwd * nblocks >= n);
if (nfwd > maxnfwd) nfwd = maxnfwd; // limit to allocated memory just in case
// ... TODO: does the above actually do anything/significant? nfwd != targetstep?
// enforce alignment
nfwd = (1 + (nfwd -1) / targetalignment) * targetalignment;
// execute it!
parallel_for (0, n, nfwd, [&](size_t j0)
{
size_t j1 = min (j0 + nfwd, n);
body (j0, j1);
});
}
};};

Просмотреть файл

@ -1,249 +0,0 @@
//
// <copyright file="readaheadsource.h" company="Microsoft">
// Copyright (c) Microsoft Corporation. All rights reserved.
// </copyright>
//
// readaheadsource.h -- wrapper ('minibatchreadaheadsource') of a read-ahead thread that pre-rolls feature and lattice data
//
#pragma once
#include "basetypes.h"
#include "minibatchiterator.h"
#include "latticearchive.h"
#ifdef _WIN32
#include "simplethread.h"
#endif
#include <deque>
#include <stdexcept>
namespace msra { namespace dbn {
// ---------------------------------------------------------------------------
// minibatchreadaheadsource -- read-ahead thread that pre-rolls feature and lattice data
// ---------------------------------------------------------------------------
class minibatchreadaheadsource : public minibatchsource/*the interface we implement*/,
noncopyable/*assignment operator needed somewhere*/,
CCritSec/*for multi-threaded access*/
{
minibatchsource & source; // the underlying source we read from
const size_t epochframes; // epoch size
unique_ptr<msra::util::simplethread> thread;
int verbosity;
// the FIFO
struct batchdata // all arguments to/from getbatch
{
size_t globalts; // time for which we get the data
// return values
msra::dbn::matrix feat;
std::vector<size_t> uids;
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> transcripts;
std::vector<shared_ptr<const latticesource::latticepair>> lattices;
batchdata (size_t globalts) : globalts (globalts) { }
};
deque<batchdata> fifo; // this is guarded by the CCritSec
size_t epoch; // which epoch we are in currently
// parameters for the thread proc (set by caller; taken over once newglobalts is set to non-SIZE_MAX (cleared back by thread))
volatile size_t newglobalts; // reset request
volatile size_t currentepochreqframes; // minibatch size for this epoch (taken from the first getbatch() call)
volatile size_t currentepochendframe; // we cannot request beyond
// signalling
mutable msra::util::signallingevent callerchangedsignal, threadchangedsignal;
void waitcallerchanged() const { callerchangedsignal.wait(); }
void flagcallerchanged() const { callerchangedsignal.flag(); }
void waitthreadchanged() const { threadchangedsignal.wait(); }
void flagthreadchanged() const { threadchangedsignal.flag(); }
// the thread proc
volatile bool terminaterequest; // threadproc must respond to this
size_t globalts; // read cursor, owned by thread only
void threadproc()
{
// note on signaling:
// This thread will always flag 'threadchangedsignal' if there is a state change,
// e.g. a new batch is available, or we have successfully initialized.
// The main ('caller') thread would check whether it finds a state it can make use of, and if not,
// it will wait for the 'threadchangedsignal' and then check again the state etc.
fprintf (stderr, "minibatchreadaheadsource: read-ahead thread entered\n");
try
{
size_t epochreqframes = 0; // minibatch size for this epoch (taken from the first getbatch() call)
size_t epochendframe = 0; // we cannot request beyond
size_t globalts = 0; // reset request
while (!terminaterequest)
{
bool stillhasdata;
{
CAutoLock lock (*this);
// if reset request then do it
if (newglobalts != SIZE_MAX)
{
// take over parameters from caller
globalts = newglobalts;
epochreqframes = currentepochreqframes;
epochendframe = currentepochendframe;
newglobalts = SIZE_MAX; // remember we got it
// reset the FIFO
fifo.clear();
flagthreadchanged(); // signal state change (needed?)
fprintf (stderr, "minibatchreadaheadsource: thread entered new epoch, frame pos reset to %d\n", (int) globalts);
continue;
}
// did we run out of data to give to the caller?
stillhasdata = !fifo.empty();
}
// we kick in once the FIFO is empty (and only once we know the mbsize)
// Note that the underlying source will be able to fulfill many more minibatches at no cost
// since we stopped pulling minibatches from it once it told us it read something from the disk.
// Thus it is OK (efficient) to run the FIFO empty before we continue asking the underlying source
// for more data--it will give us quite some more data for free--which the caller can go and process--
// before an expensive read operation is needed again.
if (globalts >= epochendframe || stillhasdata)
{
waitcallerchanged(); // nothing to do: wait for caller state change and check again
continue;
}
// we will bring in data from the current 'globalts' until the sub-getbatch() tells us
// that we loaded new data (which means subsequent getbatch() will be free until the next load).
// We assume the access pattern that
// - we start at or closely after the epoch boundary
// - we never go across an epoch boundary
// - the number of requested frames within an epoch is always the same except for the last MB
// This pattern is implemented by the minibatchiterator. We require it.
// (but it is possible that less is returned, i.e. at a sweep boundary or epoch end).
bool readfromdisk = false;
// we stop once data was read (the subsequent fetches will be cheap until the next data read)
// For small setups, all data may be in RAM and thus no reading will happen anymore.
// To guard against that, we limit the number of frames we pre-read.
fprintf (stderr, "minibatchreadaheadsource: thread entering reading loop, frame read pos %d\n", (int) globalts);
size_t batchesread = 0;
const size_t prerollendframe = globalts + 360000; // read max. 1 hour --to guard against setups that fit to RAM entirely (no disk reading after startup)
while (!terminaterequest && !readfromdisk && globalts < epochendframe && globalts < prerollendframe)
{
// get batch and append to FIFO (outside the lock)
batchdata batch (globalts);
const size_t requestedframes = min (epochreqframes, epochendframe - globalts); // we must not request beyond the epoch
readfromdisk = source.getbatch (globalts, requestedframes, batch.feat, batch.uids, batch.transcripts, batch.lattices);
batchesread++;
// Note: We may still get data beyond the end of the epoch, in utterance mode, since the epoch boundary likely falls within an utterance.
CAutoLock lock (*this);
if (!fifo.empty() && globalts != fifo.back().globalts + fifo.back().feat.cols())
throw std::logic_error ("minibatchreadaheadsource: FIFO got out of order while pre-reading new batch");
if (newglobalts != SIZE_MAX)
throw std::logic_error ("minibatchreadaheadsource: main thread reset to new epoch while current epoch not yet finished");
globalts += batch.feat.cols();
fifo.push_back (std::move (batch));
flagthreadchanged(); // signal state change so caller can pick up the new batch
}
fprintf (stderr, "minibatchreadaheadsource: thread exited reading loop, %d batches read up to frame position %d-1\n", (int) batchesread, (int) globalts);
}
fprintf (stderr, "minibatchreadaheadsource: reading loop was terminated at frame position %d-1\n", (int) globalts);
}
catch (const exception & e)
{
fprintf (stderr, "minibatchreadaheadsource: exception caught in read-ahead thread: %s\n", e.what());
thread->fail (e); // set the error first before we signal the caller
flagthreadchanged();
throw; // (this will set the error a second time; OK)
}
fprintf (stderr, "minibatchreadaheadsource: read-ahead thread exited normally\n");
}
void cancelthread() // this is only ever called by the destructor
{
fprintf (stderr, "minibatchreadaheadsource: requesting thread termination\n");
terminaterequest = true;
flagcallerchanged();
thread->wait();
}
public:
minibatchreadaheadsource (minibatchsource & source, size_t epochframes)
: source (source), epochframes (epochframes),
terminaterequest (false), globalts (SIZE_MAX),
epoch (SIZE_MAX), currentepochreqframes (0), currentepochendframe (0), newglobalts (SIZE_MAX), verbosity(2)
{
// kick off the thread
fprintf (stderr, "minibatchreadaheadsource: kicking off read-ahead thread\n");
thread.reset (new msra::util::simplethread ([this] () { threadproc(); }));
}
~minibatchreadaheadsource()
{
fprintf (stderr, "~minibatchreadaheadsource: destructing read-ahead thread\n");
cancelthread();
}
void setverbosity(int newverbosity){ verbosity = newverbosity; }
bool getbatch (const size_t globalts,
const size_t framesrequested, msra::dbn::matrix & feat, std::vector<size_t> & uids,
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> & transcripts,
std::vector<shared_ptr<const latticesource::latticepair>> & lattices)
{
#if 1
// first check whether the thread is still alive
thread->check();
// in case of epoch change, we signal the thread
size_t thisepoch = globalts / epochframes;
if (thisepoch != epoch)
{
fprintf (stderr, "minibatchreadaheadsource: signalling thread to enter new epoch\n");
epoch = thisepoch; // remember for next check --we have officially changed epochs
CAutoLock lock (*this);
if (!fifo.empty())
throw std::logic_error ("getbatch: FIFO not cleared at end of epoch");
newglobalts = globalts;
currentepochreqframes = framesrequested; // it is assumed that these won't change
currentepochendframe = (epoch + 1) * epochframes;
flagcallerchanged();
}
else if (globalts + framesrequested < currentepochendframe && currentepochreqframes != framesrequested)
throw std::logic_error ("getbatch: cannot change minibatch size mid-epoch");
// loop
bool readfromdisk = false;
for(;;) // wait for batch to appear
{
thread->check();
{
CAutoLock lock (*this);
if (!fifo.empty())
{
// get the first batch from the FIFO
batchdata front = std::move (fifo.front());
fifo.pop_front();
flagcallerchanged();
// it must be the correct one
if (front.globalts != globalts)
throw std::logic_error ("getbatch: data in FIFO out of sequence");
// return it
feat = std::move (front.feat);
uids = std::move (front.uids);
transcripts = std::move (front.transcripts);
lattices = std::move (front.lattices);
return readfromdisk;
}
}
// batch not there --keep looping
waitthreadchanged();
readfromdisk = true; // we had to wait --use to indicate that we needed to read data (does not really matter...)
}
#else
return source.getbatch (globalts, framesrequested, feat, uids, transcripts, lattices);
#endif
}
bool getbatch (const size_t globalts,
const size_t framesrequested, std::vector<msra::dbn::matrix> & feat, std::vector<std::vector<size_t>> & uids,
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> & transcripts,
std::vector<shared_ptr<const latticesource::latticepair>> & lattices)
{
feat.resize(1);
uids.resize(1);
//transcripts.resize(1);
//lattices.resize(1);
return getbatch(globalts, framesrequested, feat[0], uids[0], transcripts, lattices);
}
size_t totalframes() const { return source.totalframes(); }
size_t epochsize() const {return epochframes;}double gettimegetbatch() { return source.gettimegetbatch(); } // TODO: no, use our own time measurement
size_t firstvalidglobalts (const size_t globalts) { return source.firstvalidglobalts (globalts); }
const std::vector<size_t> & unitcounts() const { return source.unitcounts(); }
};
};};

Просмотреть файл

@ -1,827 +0,0 @@
//
// <copyright file="rollingwindowsource.h" company="Microsoft">
// Copyright (c) Microsoft Corporation. All rights reserved.
// </copyright>
//
// rollingwindowsource.h -- implementation of a rolling-window minibatch source ('minibatchframesource') with a disk page file
//
#pragma once
#include "basetypes.h" // for attempt()
//#include "numahelpers.h" // for NUMA allocation
#include "minibatchsourcehelpers.h"
#include "minibatchiterator.h"
#include "biggrowablevectors.h"
#include "ssematrix.h"
namespace msra { namespace dbn {
// ---------------------------------------------------------------------------
// biggrowablevectorarray -- a big array of vectors for features, growable (push_back)
// Data is striped across NUMA nodes, as to not clog them up.
// This also supports paging to disk, which is used for the old minibatchframesource.
// ---------------------------------------------------------------------------
class biggrowablevectorarray : public growablevectorbase<msra::dbn::matrix>
{
size_t m; // dim
size_t inmembegin; // range we have in memory, rounded to enclosing blocks (not rounded at end)
size_t inmemend;
wstring pagepath; // path for paging, empty if no paging
auto_file_ptr f; // file handle for paging
bool reading; // have we begun reading?
// allocate a block
msra::dbn::matrix * newblock() const
{
// we stripe the data across NUMA nodes as to not fill up one node with the feature data
//msra::numa::overridenode ((int) msra::numa::getmostspaciousnumanode());
msra::dbn::matrix * res = new msra::dbn::matrix (m, elementsperblock);
//msra::numa::overridenode (-1); // note: we really should reset it also in case of failure
return res;
}
// handling of page file
bool paging() const { return !pagepath.empty(); }
void openpagefile (bool wantread)
{
if (!paging()) return;
msra::files::make_intermediate_dirs (pagepath);
if (!wantread)
{
FILE *ftry = NULL;
wstring pathname (pagepath);
ftry = _wfopen (pathname.c_str(), L"wbS");
if (ftry) fclose (ftry);
}
/*
code below to cycle through a-z appended to file name is no longer necessary
since caller guarantees unique file names via HTKMLFReader
and we want the pagepath logged to the user to be the actual one used by the code
// try to open the pagepath from a to z
if (!wantread)
{
FILE *ftry = NULL;
char trynum = 'a';
while (!ftry && trynum <= 'z')
{
wstring pathname (pagepath);
pathname += trynum++;
ftry = _wfopen (pathname.c_str(), L"wbS");
}
if (ftry) fclose (ftry);
pagepath += --trynum;
}
*/
f = fopenOrDie (pagepath, wantread ? L"rbS" : L"wbS");
reading = wantread;
}
void flushlastblock() // during population phase, must be called once per block in sequence
{
if (!paging()) return;
assert (!reading);
if (blocks.empty()) return;
const size_t blockid = blocks.size() -1;
msra::dbn::matrix & block = *blocks[blockid];
assert (fgetpos (f) == blockid * block.sizeinpagefile());
block.topagefile (f);
blocks[blockid].reset(); // free the memory
assert (blockid * elementsperblock == inmembegin);
inmembegin = inmemend; // empty range
}
void releaseblock (size_t t0) // t0=block start time
{
assert (paging() && reading);
size_t blockid = t0 / elementsperblock;
assert (blockid * elementsperblock == t0);
assert (blocks[blockid]);
fprintf (stderr, "recoverblock: releasing feature block %zu [%zu..%zu)\n", blockid, t0, t0 + elementsperblock -1);
blocks[blockid].reset(); // free the memory
}
void recoverblock (size_t t0) // t0=block start time
{
assert (paging() && reading);
size_t blockid = t0 / elementsperblock;
assert (blockid * elementsperblock == t0);
assert (!blocks[blockid]);
fprintf (stderr, "recoverblock: recovering feature block %zu [%zu..%zu)\n", blockid, t0, t0 + elementsperblock -1);
blocks[blockid].reset (newblock());
msra::dbn::matrix & block = *blocks[blockid];
fsetpos (f, blockid * block.sizeinpagefile());
block.frompagefile (f);
}
public:
biggrowablevectorarray (const wstring & pagepath)
: growablevectorbase (65536), m (0),
inmembegin (0), inmemend (0), pagepath (pagepath), reading (false)
{
openpagefile (false);
if (paging())
fprintf (stderr, "biggrowablevectorarray: creating disk backup store at '%S'\n", pagepath.c_str());
}
~biggrowablevectorarray() { // clean up the big temp file
if (paging()) {
fclose (f);
if (_wunlink (pagepath.c_str())==0)
fprintf (stderr, "biggrowablevectorarray: deleted disk backup store at '%S'\n", pagepath.c_str());
else
fprintf (stderr, "biggrowablevectorarray: unable to delete disk backup store at '%S'\n", pagepath.c_str());
}
}
size_t dim() const { return m; } // dimension of a frame
// reading phase
void push_back (const std::vector<float> & in)
{
assert (!in.empty());
assert (m == 0 || m == in.size());
m = in.size();
const size_t blockid = n / elementsperblock;
assert (blockid <= blocks.size());
if (blockid == blocks.size()) // a new block is needed
{
flushlastblock();
blocks.push_back (std::unique_ptr<msra::dbn::matrix> (newblock()));
}
const size_t blockn = n % elementsperblock;
msra::dbn::matrix & block = *blocks[blockid].get();
foreach_index (k, in)
block(k,blockn) = in[k];
n++;
inmemend = n;
}
void no_more_push_back() // done pushing --switch to consumption mode
{
if (!paging()) return;
// finish off last block
flushlastblock();
fflushOrDie (f);
fprintf (stderr, "biggrowablevectorarray: disk backup store created, %d frames, %zu bytes\n", (int) n, fgetpos (f));
fclose (f);
foreach_index (i, blocks) assert (!blocks[i]); // ensure we flushed
assert (inmembegin == inmemend); // nothing in cache
// switch to reading mode
openpagefile (true);
}
// access phase
// Returns 'true' if data was actually read from disk.
bool require (pair<size_t,size_t> bounds) // we require this range of frames
{
bool readfromdisk = false;
// get bounds rounded to block boundaries
const size_t ts = bounds.first / elementsperblock * elementsperblock;
const size_t te = min (n, (bounds.second + elementsperblock -1) / elementsperblock * elementsperblock);
assert (paging());
// free all the memmory
for (size_t t = inmembegin; t < inmemend; t += elementsperblock)
{
if (t >= ts && t < te) // if in wanted range then skip to end of it
t = te - elementsperblock;
else
releaseblock (t);
}
// page in all required blocks
for (size_t t = ts; t < te; t += elementsperblock)
{
if (t >= inmembegin && t < inmemend) // if in memory already then skip to end of it
t = inmemend - elementsperblock;
else
{
recoverblock (t);
readfromdisk = true; // tell caller we did something expensive
}
}
// got it
inmembegin = ts;
inmemend = te;
return readfromdisk;
}
const msra::dbn::matrixstripe operator[] (size_t t) const // get a feature vector
{
if (t < inmembegin || t >= inmemend)
throw std::logic_error ("biggrowablevectorarray: attempt to access vector without requesting to page it in first");
const size_t blockt = getblockt (t);
/*const*/ msra::dbn::matrix & block = getblock (t);
return msra::dbn::matrixstripe (block, blockt, 1);
}
wstring pagepathname(){ return pagepath;}
void cleanuppagefile()
{
if (paging()) {
fclose (f);
if (_wunlink (pagepath.c_str())==0){
fprintf (stderr, "biggrowablevectorarray: deleted disk backup store at '%S'\n", pagepath.c_str());
}
else{
fprintf (stderr, "biggrowablevectorarray: could NOT delete disk backup store at '%S'\n", pagepath.c_str());
}
}
}
};
// ---------------------------------------------------------------------------
// minibatchframesource -- feature source to provide randomized frames in minibatches
// This is the old code that pages all frames to a huge disk file first.
// (The new minibatchutterancesource pages from input files directly and can also
// operate in utterance mode for MMI training.)
// ---------------------------------------------------------------------------
class minibatchframesource : public minibatchsource
{
size_t vdim; // feature dimension after augmenting neighhors (0: don't read features)
unsigned int sampperiod; // (for reference and to check against model)
string featkind;
size_t featdim;
// cache
biggrowablevectorarray frames; // [t][i] all features concatenated
std::vector<char> boundaryflags; // [t] -1 for first and +1 for last frame, 0 else (for augmentneighbors())
std::vector<CLASSIDTYPE> classids; // [t] the state that the frame belongs to
size_t numframes; // total frames (==frames.size()==boundaryflags.size()==classids.size()) unless special modes vdim == 0 and/or no labels
msra::dbn::randomordering randomordering; // [t] -> t'
double timegetbatch;
int verbosity;
public:
// constructor
// Pass empty labels to denote unsupervised training (so getbatch() will not return uids).
minibatchframesource (const std::vector<wstring> & infiles, const map<wstring,std::vector<msra::asr::htkmlfentry>> & labels,
size_t vdim, size_t udim, size_t randomizationrange, const wstring & pagepath, const bool mayhavenoframe=false, int addEnergy=0)
: vdim (vdim), sampperiod (0), featdim (0), numframes (0), frames (pagepath), timegetbatch (0), verbosity(2)
{
if (vdim == 0 && labels.empty())
throw runtime_error ("minibatchframesource: when running without features, labels are needed");
// at this stage, we simply page in the entire training set at once and work off RAM
// We will benefit from feature archives indirectly through htkfeatio.
// TODO:
// - infiles must specify time range
// - at this stage only reserve() (we know the time range; allocate second-layer structure)
// - implement block-wise paging directly from HTK feature files through htkfeatreader
featkind.clear();
std::vector<float> frame;
fprintf (stderr, "minibatchframesource: reading %zu utterances..", infiles.size());
size_t numclasses = 0; // number of units found (actually max id +1)
size_t notfound = 0; // number of entries missing in MLF
msra::asr::htkfeatreader reader; // feature reader
reader.AddEnergy(addEnergy);
foreach_index (i, infiles)
{
if (i % (infiles.size() / 100 + 1) == 0) { fprintf (stderr, "."); fflush (stderr); }
msra::basetypes::matrix<float> feat;
msra::asr::htkfeatreader::parsedpath ppath (infiles[i]);
// skip files for which labels don't exist (assuming bad alignment)
wstring key;
if (!labels.empty()) // empty means unsupervised mode (don't load any)
{
#ifdef _WIN32
key = regex_replace ((wstring)ppath, wregex (L"\\.[^\\.\\\\/:]*$"), wstring()); // delete extension (or not if none)
#endif
#ifdef __unix__
key = removeExtension(basename(ppath));
#endif
if (labels.find (key) == labels.end())
{
if (notfound < 5)
fprintf (stderr, "\nminibatchframesource: %d-th file not found in MLF label set: %S", i, key.c_str());
notfound++;
continue; // skip this utterance at all
}
}
// get feature frames
if (vdim != 0) // (vdim == special mode to not read features at all)
{
msra::util::attempt (5, [&]()
{
reader.read (ppath, featkind, sampperiod, feat); // whole file read as columns of feature vectors
});
if (featdim == 0) // first time
featdim = feat.rows();
else if (featdim != feat.rows())
throw std::runtime_error ("minibatchframesource: inconsistent feature dimension across files");
// HVite occasionally generates mismatching output --skip such files
if (!key.empty()) // (we have a key if supervised mode)
{
const auto & labseq = labels.find (key)->second; // (we already checked above that it exists)
size_t labframes = labseq.empty() ? 0 : (labseq[labseq.size()-1].firstframe + labseq[labseq.size()-1].numframes);
if (abs ((int) labframes - (int) feat.cols()) > 0)
{
fprintf (stderr, "\nminibatchframesource: %d-th file has small duration mismatch (%zu in label vs. %zu in feat file), skipping: %S", i, labframes, feat.cols(), key.c_str());
notfound++;
continue; // skip this utterance at all
}
}
// append to cache
frame.resize (featdim);
if (feat.cols() < 2) // (2 frames needed for boundary markers)
throw std::runtime_error ("minibatchframesource: utterances < 2 frames not supported");
foreach_column (t, feat)
{
foreach_index (k, frame)
frame[k] = feat(k,t);
frames.push_back (frame);
numframes++;
boundaryflags.push_back ((t == 0) ? -1 : (t == feat.cols() -1) ? +1 : 0);
}
assert (numframes == frames.size());
assert (numframes == boundaryflags.size());
}
// get label sequence
if (!key.empty()) // (we have a key if supervised mode)
{
const auto & labseq = labels.find (key)->second; // (we already checked above that it exists)
foreach_index (i, labseq)
{
const auto & e = labseq[i];
if ((i > 0 && labseq[i-1].firstframe + labseq[i-1].numframes != e.firstframe) || (i == 0 && e.firstframe != 0))
throw std::runtime_error (msra::strfun::strprintf ("minibatchframesource: labels not in consecutive order MLF in label set: %S", key.c_str()));
for (size_t t = e.firstframe; t < e.firstframe + e.numframes; t++)
{
if (e.classid >= udim)
throw std::runtime_error (msra::strfun::strprintf ("minibatchframesource: class id exceeds model dimension in file %S", key.c_str()));
if (e.classid != (CLASSIDTYPE) e.classid)
throw std::runtime_error ("CLASSIDTYPE has too few bits");
classids.push_back ((CLASSIDTYPE) e.classid);
numclasses = max ((size_t)numclasses, (size_t)(1u + e.classid));
}
}
if (vdim == 0)
numframes = classids.size();
if (numframes != classids.size()) // TODO: remove this once we are confident
throw std::runtime_error (msra::strfun::strprintf ("minibatchframesource: label duration inconsistent with feature file in MLF label set: %S", key.c_str()));
assert (numframes == classids.size());
}
else
{
assert (classids.empty()); // that's how we detect it later
}
}
assert (vdim == 0 || numframes == frames.size());
assert (labels.empty() || numframes == classids.size());
if ((vdim != 0 && numframes != frames.size()) || (!labels.empty() && numframes != classids.size()))
throw std::runtime_error ("minibatchframesource: numframes variable screwup");
fprintf (stderr, " %zu frames read from %zu utterances; %zu classes\n", numframes, infiles.size(), numclasses);
if (notfound > 0)
{
fprintf (stderr, "minibatchframesource: %zu files out of %zu not found in label set\n", notfound, infiles.size());
if (notfound > infiles.size() / 2)
throw std::runtime_error ("minibatchframesource: too many files not found in label set--assuming broken configuration\n");
}
if (numframes == 0 && !mayhavenoframe)
throw std::runtime_error ("minibatchframesource: no input features given!");
// notify frames source to switch from population to consumption mode
frames.no_more_push_back();
// initialize randomizer
if (numframes > 0)
randomordering.resize (numframes, randomizationrange);
}
virtual ~minibatchframesource() {}
size_t totalframes() const { assert (vdim == 0 || numframes == frames.size()); assert (!issupervised() || numframes == classids.size()); return numframes; }
bool issupervised() const { return !classids.empty(); }
void setverbosity(int newverbosity) { verbosity = newverbosity; }
// retrieve one minibatch
// Minibatches are deterministic pseudo-random samples. The entire corpus
// is repeated infinitely, but each repetition (a 'sweep') is randomized
// differently.
// This function allows to retrieve a mini-batch starting from any frame
// within this infinitely extended repetition. To the end, mini-batches are
// specified by start frame and #frames.
// This function returns the same data independent on #frames, i.e. the concept
// of the mini-batch is not defined in here, but on the caller side. The caller
// can retrieve the frames of a mini-batch in chunks that do not match the
// caller's definition of "mini-batch," e.g. bigger or smaller chunks.
// If a requested mini-batch spans a sweep boundary, then this function will
// not return samples after the sweep boundary. Instead, the returned frame
// set is shortened to not exceed the end of the sweep. The caller must make
// a separate second call to get the rest. In trainlayer(), the one
// sweep-boundary-spanning mini-batch will simply be shortened.
// This function is NOT thread-safe (due to caching of random sequence).
bool getbatch (const size_t globalts, const size_t framesrequested, msra::dbn::matrix & feat, std::vector<size_t> & uids,
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> & transcripts,
std::vector<shared_ptr<const latticesource::latticepair>> & latticepairs)
{
auto_timer timergetbatch;
transcripts.clear(); // word-level transcripts not supported by frame source (aimed at MMI)
latticepairs.clear(); // neither are lattices
assert (totalframes() > 0);
const size_t sweep = globalts / totalframes(); // which sweep (this determines randomization)
const size_t ts = globalts % totalframes(); // start frame within the sweep
const size_t te = min (ts + framesrequested, totalframes()); // do not go beyond sweep boundary
assert (te > ts);
if (verbosity >= 2)
fprintf (stderr, "getbatch: frames [%zu..%zu] in sweep %zu\n", ts, te-1, sweep);
// get random sequence (each time index occurs exactly once)
// If the sweep changes, this will re-cache the sequence. We optimize for rare, monotonous sweep changes.
const auto & tmap = randomordering (sweep);
// page in the needed range of frames
const size_t extent = augmentationextent (frames.dim(), vdim);
bool readfromdisk = frames.require (randomordering.bounds (max (ts, extent) - extent, te + 1 + extent));
// generate features and uids
feat.resize (vdim, te - ts); // note: special mode vdim == 0 means no features to be loaded
if (issupervised()) // empty means unsupervised training -> return empty uids
uids.resize (te - ts);
else
uids.clear();
for (size_t t = ts; t < te; t++)
{
size_t trand = tmap[t]; // the random-sequence sample point for this point in time
if (vdim != 0)
{
auto v_t = feat.col(t-ts); // the vector to fill in
augmentneighbors (frames, boundaryflags, trand, v_t);
}
if (issupervised())
uids[t-ts] = classids[trand];
}
timegetbatch = timergetbatch;
return readfromdisk;
}
bool getbatch (const size_t globalts, const size_t framesrequested, std::vector<msra::dbn::matrix> & feat, std::vector<std::vector<size_t>> & uids,
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> & transcripts,
std::vector<shared_ptr<const latticesource::latticepair>> & latticepairs)
{
// for single input/output set size to be 1 and run old getbatch
feat.resize(1);
uids.resize(1);
//transcripts.resize(1);
//latticepairs.resize(1);
return getbatch(globalts, framesrequested, feat[0], uids[0], transcripts, latticepairs);
}
double gettimegetbatch () { return timegetbatch;}
// return first valid globalts to ask getbatch() for
// In frame mode, there is no constraint, i.e. it is 'globalts' itself.
/*implement*/ size_t firstvalidglobalts (const size_t globalts) { return globalts; }
/*implement*/ const std::vector<size_t> & unitcounts() const { throw logic_error ("unitcounts: not implemented for this feature source"); static std::vector<size_t> x; return x;/*keep compiler happy*/ }
};
// ---------------------------------------------------------------------------
// minibatchframesourcemulti -- feature source to provide randomized frames in minibatches
// this is derived from minibatchframesource but worked with multiple inputs and/or outputs
// by making "frames" and "classids" a vector of vectors
// ---------------------------------------------------------------------------
class minibatchframesourcemulti : public minibatchsource
{
std::vector<size_t> vdim; // feature dimension after augmenting neighhors (0: don't read features)
std::vector<size_t> leftcontext; // number of frames to the left of the target frame in the context window
std::vector<size_t> rightcontext; // number of frames to the right of the target frame in the context window
unsigned int sampperiod; // (for reference and to check against model)
string featkind;
size_t featdim;
size_t maxvdim;
// cache
//std::vector<biggrowablevectorarray> frames;
std::vector<unique_ptr<biggrowablevectorarray>> pframes; // [t][i] all features concatenated
std::vector<char> boundaryflags; // [t] -1 for first and +1 for last frame, 0 else (for augmentneighbors())
std::vector<std::vector<CLASSIDTYPE>> classids; // [t] the state that the frame belongs to
size_t numframes; // total frames (==frames.size()==boundaryflags.size()==classids.size()) unless special modes vdim == 0 and/or no labels
msra::dbn::randomordering randomordering; // [t] -> t'
double timegetbatch;
int verbosity;
public:
// constructor
// Pass empty labels to denote unsupervised training (so getbatch() will not return uids).
minibatchframesourcemulti (const std::vector<std::vector<wstring>> & infiles, const std::vector<map<std::wstring,std::vector<msra::asr::htkmlfentry>>> & labels,
std::vector<size_t> vdim, std::vector<size_t> udim, std::vector<size_t> leftcontext, std::vector<size_t> rightcontext, size_t randomizationrange, const std::vector<wstring> & pagepath, const bool mayhavenoframe=false, int addEnergy=0)
: vdim (vdim), leftcontext(leftcontext), rightcontext(rightcontext), sampperiod (0), featdim (0), numframes (0), timegetbatch (0), verbosity(2), maxvdim(0)
{
if (vdim[0] == 0 && labels.empty())
throw runtime_error ("minibatchframesourcemulti: when running without features, labels are needed");
// at this stage, we simply page in the entire training set at once and work off RAM
// We will benefit from feature archives indirectly through htkfeatio.
// TODO:
// - infiles must specify time range
// - at this stage only reserve() (we know the time range; allocate second-layer structure)
// - implement block-wise paging directly from HTK feature files through htkfeatreader
featkind.clear();
std::vector<float> frame;
std::vector<size_t>numclasses; // number of units found (actually max id +1)
size_t notfound = 0; // number of entries missing in MLF
std::vector<size_t>framesaccum;
if (infiles.size()==0)
throw runtime_error("minibatchframesourcemulti: need at least one network input specified with features");
if (labels.size()==0)
fprintf(stderr,"no MLF label files detected\n");
foreach_index (i, infiles)
{
pframes.push_back(unique_ptr<biggrowablevectorarray>(new biggrowablevectorarray(pagepath[i])));
if (vdim[i]>maxvdim)
maxvdim=vdim[i];
}
foreach_index (i, labels)
{
classids.push_back(std::vector<CLASSIDTYPE>());
numclasses.push_back(0);
}
fprintf (stderr, "minibatchframesourcemulti: reading %zu feature sets and %zu label sets...", infiles.size(),labels.size());
foreach_index (m, infiles)
{
featdim=0;
numframes=0;
featkind.clear();
msra::asr::htkfeatreader reader; // feature reader
reader.AddEnergy(addEnergy);
foreach_index (i, infiles[m]) // read each feature file in set m
{
if (i % (infiles[m].size() / 100 + 1) == 0) { fprintf (stderr, "."); fflush (stderr); }
msra::basetypes::matrix<float> feat;
msra::asr::htkfeatreader::parsedpath ppath (infiles[m][i]);
// skip files for which labels don't exist (assuming bad alignment)
wstring key;
if (!labels.empty())
{
if (!labels[0].empty()) // empty means unsupervised mode (don't load any)
{
#ifdef _WIN32
key = regex_replace ((wstring)ppath, wregex (L"\\.[^\\.\\\\/:]*$"), wstring()); // delete extension (or not if none)
#endif
#ifdef __unix__
key = removeExtension(basename(ppath));
#endif
if (labels[0].find (key) == labels[0].end())
{
if (notfound < 5)
fprintf (stderr, "\nminibatchframesourcemulti: %d-th file not found in MLF label set: %S", i, key.c_str());
notfound++;
continue; // skip this utterance at all
}
}
}
// get feature frames
if (vdim[m] != 0) // (vdim == special mode to not read features at all)
{
msra::util::attempt (5, [&]()
{
reader.read (ppath, featkind, sampperiod, feat); // whole file read as columns of feature vectors
});
if (featdim == 0) // first time
featdim = feat.rows();
else if (featdim != feat.rows())
throw std::runtime_error ("minibatchframesourcemulti: inconsistent feature dimension across files");
// HVite occasionally generates mismatching output --skip such files
if (!key.empty()) // (we have a key if supervised mode)
{
const auto & labseq = labels[0].find (key)->second; // (we already checked above that it exists)
size_t labframes = labseq.empty() ? 0 : (labseq[labseq.size()-1].firstframe + labseq[labseq.size()-1].numframes);
if (abs ((int) labframes - (int) feat.cols()) > 0)
{
fprintf (stderr, "\nminibatchframesourcemulti: %d-th file has small duration mismatch (%zu in label vs. %zu in feat file), skipping: %S", i, labframes, feat.cols(), key.c_str());
notfound++;
continue; // skip this utterance at all
}
}
// append to cache
frame.resize (featdim);
if (feat.cols() < 2) // (2 frames needed for boundary markers)
throw std::runtime_error ("minibatchframesourcemulti: utterances < 2 frames not supported");
foreach_column (t, feat)
{
foreach_index (k, frame)
frame[k] = feat(k,t);
pframes[m]->push_back (frame);
numframes++;
if (m==0)
boundaryflags.push_back ((t == 0) ? -1 : (t == feat.cols() -1) ? +1 : 0);
}
if (m==0)
framesaccum.push_back(numframes);
else
assert(numframes == framesaccum[i]);
assert (numframes == pframes[m]->size());
}
if (m==0)
assert (numframes == boundaryflags.size());
if (m==0) // after we get the key for this file, read all labels (only done for first feature)
{
if (!key.empty())
{
foreach_index (j, labels)
{
const auto & labseq = labels[j].find (key)->second; // (we already checked above that it exists)
foreach_index (i, labseq)
{
const auto & e = labseq[i];
if ((i > 0 && labseq[i-1].firstframe + labseq[i-1].numframes != e.firstframe) || (i == 0 && e.firstframe != 0))
throw std::runtime_error (msra::strfun::strprintf ("minibatchframesourcemulti: labels not in consecutive order MLF in label set: %S", key.c_str()));
for (size_t t = e.firstframe; t < e.firstframe + e.numframes; t++)
{
if (e.classid >= udim[j])
throw std::runtime_error (msra::strfun::strprintf ("minibatchframesourcemulti: class id exceeds model dimension in file %S", key.c_str()));
if (e.classid != (CLASSIDTYPE) e.classid)
throw std::runtime_error ("CLASSIDTYPE has too few bits");
classids[j].push_back ((CLASSIDTYPE) e.classid);
numclasses[j] = max (numclasses[j], (long unsigned int)(1u + e.classid));
}
}
if (vdim[m] == 0)
numframes = classids[j].size();
if (numframes != classids[j].size()) // TODO: remove this once we are confident
throw std::runtime_error (msra::strfun::strprintf ("minibatchframesourcemulti: label duration inconsistent with feature file in MLF label set: %S", key.c_str()));
assert (numframes == classids[j].size());
}
}
else
{
assert(classids.empty());
}
}
}
assert (vdim[m] == 0 || numframes == pframes[m]->size());
foreach_index(j, labels)
assert (labels[j].empty() || numframes == classids[j].size());
if (vdim[m] != 0 && numframes != pframes[m]->size()) // || (!labels.empty() && numframes != classids.size()))
throw std::runtime_error ("\nminibatchframesource: numframes variable screwup");
if (m==0)
{
foreach_index (j, numclasses)
fprintf (stderr, "\nminibatchframesourcemulti: read label set %d: %zu classes\n", j, numclasses[j]);
}
fprintf (stderr, "\nminibatchframesourcemulti: feature set %d: %zu frames read from %zu utterances\n", m, pframes[m]->size(), infiles[m].size());
if (notfound > 0)
{
fprintf (stderr, "minibatchframesourcemulti: %zu files out of %zu not found in label set\n", notfound, infiles[m].size());
if (notfound > infiles[m].size() / 2)
throw std::runtime_error ("minibatchframesourcemulti: too many files not found in label set--assuming broken configuration\n");
}
// notify frames source to switch from population to consumption mode
pframes[m]->no_more_push_back();
}
if (numframes == 0 && !mayhavenoframe)
throw std::runtime_error ("minibatchframesource: no input features given!");
// initialize randomizer
if (numframes > 0)
randomordering.resize (numframes, randomizationrange);
}
virtual ~minibatchframesourcemulti() {}
size_t totalframes() const {
assert (maxvdim == 0 || numframes == pframes[0]->size()); assert (!issupervised() || numframes == classids[0].size()); return numframes; }
bool issupervised() const { return !classids.empty(); }
void setverbosity(int newverbosity) { verbosity = newverbosity; }
// retrieve one minibatch
// Minibatches are deterministic pseudo-random samples. The entire corpus
// is repeated infinitely, but each repetition (a 'sweep') is randomized
// differently.
// This function allows to retrieve a mini-batch starting from any frame
// within this infinitely extended repetition. To the end, mini-batches are
// specified by start frame and #frames.
// This function returns the same data independent on #frames, i.e. the concept
// of the mini-batch is not defined in here, but on the caller side. The caller
// can retrieve the frames of a mini-batch in chunks that do not match the
// caller's definition of "mini-batch," e.g. bigger or smaller chunks.
// If a requested mini-batch spans a sweep boundary, then this function will
// not return samples after the sweep boundary. Instead, the returned frame
// set is shortened to not exceed the end of the sweep. The caller must make
// a separate second call to get the rest. In trainlayer(), the one
// sweep-boundary-spanning mini-batch will simply be shortened.
// This function is NOT thread-safe (due to caching of random sequence).
bool getbatch (const size_t globalts, const size_t framesrequested, std::vector<msra::dbn::matrix> & feat, std::vector<std::vector<size_t>> & uids,
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> & transcripts,
std::vector<shared_ptr<const latticesource::latticepair>> & latticepairs)
{
auto_timer timergetbatch;
bool readfromdisk;
size_t nreadfromdisk=0;
transcripts.clear(); // word-level transcripts not supported by frame source (aimed at MMI)
latticepairs.clear(); // neither are lattices
assert (totalframes() > 0);
const size_t sweep = globalts / totalframes(); // which sweep (this determines randomization)
const size_t ts = globalts % totalframes(); // start frame within the sweep
const size_t te = min (ts + framesrequested, totalframes()); // do not go beyond sweep boundary
assert (te > ts);
if (verbosity >= 2)
fprintf (stderr, "getbatch: frames [%zu..%zu] in sweep %zu\n", ts, te-1, sweep);
// get random sequence (each time index occurs exactly once)
// If the sweep changes, this will re-cache the sequence. We optimize for rare, monotonous sweep changes.
const auto & tmap = randomordering (sweep);
feat.resize(pframes.size());
uids.resize(classids.size());
foreach_index(i, feat)
{
size_t leftextent, rightextent;
// page in the needed range of frames
if (leftcontext[i] == 0 && rightcontext[i] == 0)
{
leftextent = rightextent = augmentationextent(pframes[i]->dim(), vdim[i]);
}
else
{
leftextent = leftcontext[i];
rightextent = rightcontext[i];
}
readfromdisk = pframes[i]->require (randomordering.bounds (max (ts, leftextent) - leftextent, te + 1 + rightextent));
// generate features and uids
feat[i].resize (vdim[i], te - ts); // note: special mode vdim == 0 means no features to be loaded
if (issupervised()) // empty means unsupervised training -> return empty uids
foreach_index(j, uids)
uids[j].resize (te - ts);
else
uids.clear();
for (size_t t = ts; t < te; t++)
{
size_t trand = tmap[t]; // the random-sequence sample point for this point in time
if (vdim[i] != 0)
{
auto v_t = feat[i].col(t-ts); // the vector to fill in
augmentneighbors (*pframes[i], boundaryflags, trand, leftextent, rightextent, v_t);
}
if (i==0){ // read labels for all outputs on first pass thru features. this guarantees they will be read if only one feature set but > 1 label set
if (issupervised())
foreach_index(j, uids)
uids[j][t-ts] = classids[j][trand];
}
}
timegetbatch = timergetbatch;
if (readfromdisk)
nreadfromdisk++;
}
(nreadfromdisk==feat.size()) ? readfromdisk = true : readfromdisk = false;
return readfromdisk;
}
bool getbatch (const size_t /*globalts*/, const size_t /*framesrequested*/, msra::dbn::matrix & /*feat*/, std::vector<size_t> & /*uids*/,
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> & /*transcripts*/,
std::vector<shared_ptr<const latticesource::latticepair>> & /*latticepairs*/)
{
// should never get here
throw runtime_error("minibatchframesourcemulti: getbatch() being called for single input feature and single output feature, should use minibatchframesource instead\n");
}
double gettimegetbatch () { return timegetbatch;}
// return first valid globalts to ask getbatch() for
// In frame mode, there is no constraint, i.e. it is 'globalts' itself.
/*implement*/ size_t firstvalidglobalts (const size_t globalts) { return globalts; }
/*implement*/ const std::vector<size_t> & unitcounts() const { throw logic_error ("unitcounts: not implemented for this feature source"); }
};
};};

Просмотреть файл

@ -1,89 +0,0 @@
//
// <copyright file="simple_checked_arrays.h" company="Microsoft">
// Copyright (c) Microsoft Corporation. All rights reserved.
// </copyright>
//
// simple_checked_arrays.h -- a simple wrapper around pointers used as arrays to allow bounds checking
//
#pragma once
#include <stddef.h> // for size_t
#include <assert.h>
// ---------------------------------------------------------------------------
// array_ref -- wraps a C pointer to an array together with its size.
//
// Called _ref because this is a reference to the array rather than the array
// itself (since it wraps a pointer). No need to pass an array_ref by reference.
//
// operator[] checks index bounds in Debug builds. size() is provided such
// that this class can be substituted for STL vector in many cases.
// ---------------------------------------------------------------------------
template<class _T> class array_ref
{
_T * data;
size_t n;
inline void check_index (size_t i) const { i; assert (i < n); }
inline void check_ptr() const { n; data; assert (n == 0 || data != NULL); }
public:
inline array_ref (_T * ptr, size_t size) throw() : data (ptr), n (size) { }
inline array_ref() throw() : data (NULL), n (0) { } // in case we have a vector of this
inline _T & operator[] (size_t i) throw() { check_index (i); return data[i]; }
inline const _T & operator[] (size_t i) const throw() { check_index (i); return data[i]; }
inline size_t size() const throw() { return n; }
inline _T * begin() { return data; }
inline _T * end() { return data + n; }
inline void resize (size_t sz) { sz; assert (n == sz); } // allow compatibility with some functions
// construct from other vector types
template<class _V> inline array_ref (_V & v) : data (v.size() > 0 ? &v[0] : NULL), n ((size_t) v.size()) { }
};
// ---------------------------------------------------------------------------
// const_array_ref -- same as array_ref for 'const' (read-only) pointers
// ---------------------------------------------------------------------------
template<class _T> class const_array_ref
{
const _T * data;
size_t n;
inline void check_index (size_t i) const { i; assert (i < n); }
inline void check_ptr() const { n; data; assert (n == 0 || data != NULL); }
public:
inline const_array_ref (const _T * ptr, size_t size) throw() : data (ptr), n (size) { }
inline const_array_ref() throw() : data (NULL), n (0) { } // in case we have a vector of this
inline const _T & operator[] (size_t i) const throw() { check_index (i); return data[i]; }
inline size_t size() const throw() { return n; }
inline const _T * begin() { return data; }
inline const _T * end() { return data + n; }
inline const _T & front() const throw() { check_index (0); return data[0];}
inline const _T & back() const throw() {check_index (0); return data[n-1];}
// construct from other vector types
template<class _V> inline const_array_ref (const _V & v) : data (v.size() > 0 ? &v[0] : NULL), n ((size_t) v.size()) { }
};
// ---------------------------------------------------------------------------
// hardcoded_array -- wraps a fixed-size C array together with its size.
//
// operator[] checks index bounds in Debug builds. size() is provided such
// that this class can be substituted for STL vector in many cases.
// Can be constructed with a size parameter--it will be checked against the
// hard-coded size.
// Can also be constructed with an initialization parameter (typ. 0).
// ---------------------------------------------------------------------------
template<class _T, int _N> class hardcoded_array
{
_T data[_N];
inline void check_index (size_t i) const { i; assert (i < _N); }
inline void check_size (size_t n) const { n; assert (n == _N); }
public:
inline hardcoded_array() throw() {}
inline hardcoded_array (size_t n) throw() { check_size (n); } // we can instantiate with a size parameter--just checks the size
inline hardcoded_array (size_t n, const _T & val) throw() { check_size (n); for (size_t i = 0; i < n; i++) data[i] = val; }
inline _T & operator[] (size_t i) throw() { check_index (i); return data[i]; }
inline const _T & operator[] (size_t i) const throw() { check_index (i); return data[i]; }
inline size_t size() const throw() { return _N; }
};

Просмотреть файл

@ -1,241 +0,0 @@
//
// <copyright file="simplesenonehmm.h" company="Microsoft">
// Copyright (c) Microsoft Corporation. All rights reserved.
// </copyright>
//
// latticearchive.h -- managing lattice archives
//
#pragma once
#include "basetypes.h"
#include "fileutil.h"
#include <vector>
#include <string>
#include <unordered_map>
#include <algorithm> // for find()
#include "simple_checked_arrays.h"
namespace msra { namespace asr {
// ===========================================================================
// simplesenonehmm -- simple senone-based CD-HMM
// ===========================================================================
class simplesenonehmm
{
public: // (TODO: better encapsulation)
static const size_t MAXSTATES = 3; // we use a fixed memory allocation since it's almost always 3 anyway
struct transP;
struct hmm
{
const char * name; // (this points into the key in the hash table to save memory)
struct transP * transP; // underlying transition matrix
unsigned char transPindex; // index of transP in struct transP
unsigned char numstates; // number of states
unsigned short senoneids[MAXSTATES]; // [0..numstates-1] senone indices
const char * getname() const { return name; } // (should be used for diagnostics only)
size_t getsenoneid (size_t i) const { if (i < numstates) return (size_t) senoneids[i]; throw std::logic_error ("getsenoneid: out of bounds access"); }
size_t getnumstates() const { return (size_t) numstates; }
unsigned char gettransPindex() const { return transPindex;}
const struct transP & gettransP() const { return *transP; }
bool operator< (const hmm & other) const
{
return memcmp (this, &other, sizeof (other)) < 0;
}
};
std::vector<hmm> hmms; // the set of HMMs
std::unordered_map<std::string,size_t> symmap; // [name] -> index into hmms[]
struct transP
{
private:
size_t numstates;
float loga[MAXSTATES+1][MAXSTATES+1];
void check (int from, size_t to) const { if (from < -1 || from >= (int) numstates || to > numstates) throw std::logic_error ("transP: index out of bounds"); }
public:
void resize (size_t n) { if (n > MAXSTATES) throw std::runtime_error ("resize: requested transP that exceeds MAXSTATES"); numstates = n; }
size_t getnumstates() const { return numstates; }
// from = -1 and to = numstates are allowed, but we also allow 'from' to be size_t to avoid silly typecasts
float & operator() (int from, size_t to) { check (from, to); return loga[from+1][to]; } // from >= -1
const float & operator() (int from, size_t to) const { check (from, to); return loga[from+1][to]; } // from >= -1
const float & operator() (size_t from, size_t to) const { check ((int)from, to); return loga[from+1][to]; } // from >= 0
transP() : numstates (0) {}
};
std::vector<transP> transPs; // the transition matrices --TODO: finish this
std::hash_map<std::string,size_t> transPmap; // [transPname] -> index into transPs[]
public:
// get an hmm by index
const hmm & gethmm (size_t i) const { return hmms[i]; }
// get an hmm by name
size_t gethmmid (const string & name) const
{
auto iter = symmap.find (name);
if (iter == symmap.end())
throw std::logic_error ("gethmm: unknown unit name: " + name);
return iter->second;
}
// diagnostics: map state id to senone name
std::vector<std::string> statenames;
const char * getsenonename (size_t senoneid) const { return statenames[senoneid].c_str(); }
// inverse lookup, for re-scoring the ground-truth path for sequence training
// This may be ambiguous, but we know that for current setup, that's only the case for /sil/ and /sp/.
std::vector<int> senoneid2transPindex; // or -1 if ambiguous
std::vector<int> senoneid2stateindex; // 0..2, or -1 if ambiguous
// construct from model files
simplesenonehmm (const std::wstring & cdphonetyingpath, const std::wstring & statelistpath, const std::wstring & transPpath)
{
if (cdphonetyingpath.empty()) // no tying info specified --just leave an empty object
return;
fprintf (stderr, "simplesenonehmm: reading '%S', '%S', '%S'\n", cdphonetyingpath.c_str(), statelistpath.c_str(), transPpath.c_str());
// read the state list
vector<char> textbuffer;
auto readstatenames = msra::files::fgetfilelines (statelistpath, textbuffer);
foreach_index (s, readstatenames)
statenames.push_back (readstatenames[s]);
std::unordered_map<std::string,size_t> statemap; // [name] -> index
statemap.rehash (readstatenames.size());
foreach_index (i, readstatenames)
statemap[readstatenames[i]] = i;
// TRANSPNAME NUMSTATES (ROW_from[to])+
msra::strfun::tokenizer toks (" \t", 5);
auto transPlines = msra::files::fgetfilelines (transPpath, textbuffer);
transPs.resize (transPlines.size());
string key; key.reserve (100);
foreach_index (i, transPlines)
{
toks = transPlines[i];
if (toks.size() < 3)
throw std::runtime_error ("simplesenonehmm: too few tokens in transP line: " + string (transPlines[i]));
key = toks[0]; // transPname --using existing object to avoid malloc
transPmap[key] = i;
size_t numstates = msra::strfun::toint (toks[1]);
if (numstates == 0)
throw std::runtime_error ("simplesenonehmm: invalid numstates: " + string (transPlines[i]));
auto & transP = transPs[i];
transP.resize (numstates);
size_t k = 2; // index into tokens; transP values start at toks[2]
for (int from = -1; from < (int) numstates; from++) for (size_t to = 0; to <= numstates; to++)
{
if (k >= toks.size())
throw std::runtime_error ("simplesenonehmm: not enough tokens on transP line: " + string (transPlines[i]));
const char * sval = toks[k++];
const double aij = msra::strfun::todouble (sval);
if (aij > 1e-10) // non-0
transP(from,to) = logf ((float) aij); // we store log probs
else
transP(from,to) = -1e30f;
}
if (toks.size() > k)
throw std::runtime_error ("simplesenonehmm: unexpected garbage at endof transP line: " + string (transPlines[i]));
}
// allocate inverse lookup
senoneid2transPindex.resize (readstatenames.size(), -2);
senoneid2stateindex.resize (readstatenames.size(), -2);
// read the cd-phone tying info
// HMMNAME TRANSPNAME SENONENAME+
auto lines = msra::files::fgetfilelines (cdphonetyingpath, textbuffer);
hmms.reserve (lines.size());
symmap.rehash (lines.size());
// two tables: (1) name -> HMM; (2) HMM -> HMM index (uniq'ed)
map<string,hmm> name2hmm; // [name] -> unique HMM struct (without name)
map<hmm,size_t> hmm2index; // [unique HMM struct] -> hmm index, hmms[i] contains full hmm
foreach_index (i, lines)
{
toks = lines[i];
if (toks.size() < 3)
throw std::runtime_error ("simplesenonehmm: too few tokens in line: " + string (lines[i]));
const char * hmmname = toks[0];
const char * transPname = toks[1];
// build the HMM structure
hmm hmm;
hmm.name = NULL; // for use as key in hash tables, we keep this NULL
// get the transP pointer
// TODO: this becomes a hard lookup with failure
key = transPname; // (reuse existing memory)
auto iter = transPmap.find (key);
if (iter == transPmap.end())
throw std::runtime_error ("simplesenonehmm: unknown transP name: " + string (lines[i]));
size_t transPindex = iter->second;
hmm.transPindex = (unsigned char) transPindex;
hmm.transP = &transPs[transPindex];
if (hmm.transPindex != transPindex)
throw std::runtime_error ("simplesenonehmm: numeric overflow for transPindex field");
// get the senones
hmm.numstates = (unsigned char) (toks.size() - 2); // remaining tokens
if (hmm.numstates != transPs[transPindex].getnumstates())
throw std::runtime_error ("simplesenonehmm: number of states mismatches that of transP: " + string (lines[i]));
if (hmm.numstates > _countof (hmm.senoneids))
throw std::runtime_error ("simplesenonehmm: hmm.senoneids[MAXSTATES] is too small in line: " + string (lines[i]));
for (size_t s = 0; s < hmm.numstates; s++)
{
const char * senonename = toks[s+2];
key = senonename; // (reuse existing memory)
auto iter = statemap.find (key);
if (iter == statemap.end())
throw std::runtime_error ("simplesenonehmm: unrecognized senone name in line: " + string (lines[i]));
hmm.senoneids[s] = (unsigned short) iter->second;
if (hmm.getsenoneid(s) != iter->second)
throw std::runtime_error ("simplesenonehmm: not enough bits to store senone index in line: " + string (lines[i]));
// inverse lookup
if (senoneid2transPindex[hmm.senoneids[s]] == -2) // no value yet
senoneid2transPindex[hmm.senoneids[s]] = hmm.transPindex;
else if (senoneid2transPindex[hmm.senoneids[s]] != hmm.transPindex)
senoneid2transPindex[hmm.senoneids[s]] = -1; // multiple inconsistent values
if (senoneid2stateindex[hmm.senoneids[s]] == -2)
senoneid2stateindex[hmm.senoneids[s]] = (int) s;
else if (senoneid2stateindex[hmm.senoneids[s]] != (int) s)
senoneid2stateindex[hmm.senoneids[s]] = -1;
}
for (size_t s = hmm.numstates; s < _countof (hmm.senoneids); s++) // clear out the rest if needed
hmm.senoneids[s] = USHRT_MAX;
// add to name-to-HMM hash
auto ir = name2hmm.insert (std::make_pair (hmmname, hmm)); // insert into hash table
if (!ir.second) // not inserted
throw std::runtime_error ("simplesenonehmm: duplicate unit name in line: " + string (lines[i]));
// add to hmm-to-index hash
// and update the actual lookup table
size_t hmmindex = hmms.size(); // (assume it's a new entry)
auto is = hmm2index.insert (std::make_pair (hmm, hmmindex));
if (is.second) // was indeed inserted: add to hmms[]
{
// insert first, as this copies the name; we can then point to it
auto it = symmap.insert (std::make_pair (hmmname, hmmindex)); // insert into hash table
hmm.name = it.first->first.c_str(); // only use first name if multiple (the name is informative only anyway)
hmms.push_back (hmm);
}
else // not inserted
{
hmmindex = is.first->second; // use existing value
symmap.insert (std::make_pair (hmmname, hmmindex)); // insert into hash table
}
}
fprintf (stderr, "simplesenonehmm: %zu units with %zu unique HMMs, %zu tied states, and %zu trans matrices read\n",
symmap.size(), hmms.size(), statemap.size(), transPs.size());
}
// exposed so we can pass it to the lattice reader, which maps the symbol ids for us
const std::unordered_map<std::string,size_t> & getsymmap() const { return symmap; }
// inverse lookup --for scoring the ground-truth
// Note: /sil/ and /sp/ will be ambiguous, so need to handle them as a special case.
int senonetransP (size_t senoneid) const { return senoneid2transPindex[senoneid]; }
int senonestate (size_t senoneid) const { return senoneid2stateindex[senoneid]; }
const size_t getnumsenone () const {return senoneid2stateindex.size(); }
const bool statebelongstohmm (const size_t senoneid, const hmm & hmm) const // reutrn true if one of the states of this hmm == senoneid
{
size_t numstates = hmm.getnumstates();
for (size_t i = 0; i < numstates; i++)
if (hmm.senoneids[i] == senoneid)
return true;
return false;
}
};
};};

Просмотреть файл

@ -1,152 +0,0 @@
//
// <copyright file="simplethread.h" company="Microsoft">
// Copyright (c) Microsoft Corporation. All rights reserved.
// </copyright>
//
// simplethread.h -- a simple thread implementation
//
#pragma once
#include "basetypes.h"
#ifdef _WIN32
#include <process.h> // for _beginthread()
#endif
namespace msra { namespace util {
// ---------------------------------------------------------------------------
// signallingevent -- wrapper around Windows events
// ---------------------------------------------------------------------------
class signallingevent // TODO: should this go into basetypes.h?
{
HANDLE h;
public:
signallingevent (bool initialstate = true)
{
h = ::CreateEvent (NULL, FALSE/*manual reset*/, initialstate ? TRUE : FALSE, NULL);
if (h == NULL)
throw std::runtime_error ("signallingevent: CreateEvent() failed");
}
~signallingevent() { ::CloseHandle (h); }
void wait() { if (::WaitForSingleObject (h, INFINITE) != WAIT_OBJECT_0) throw std::runtime_error ("wait: WaitForSingleObject() unexpectedly failed"); }
void flag() { if (::SetEvent (h) == 0) throw std::runtime_error ("flag: SetEvent() unexpectedly failed"); }
};
// ---------------------------------------------------------------------------
// simplethread -- simple thread wrapper
// ---------------------------------------------------------------------------
class simplethread : CCritSec
{
std::shared_ptr<std::exception> badallocexceptionptr; // in case we fail to copy the exception
std::shared_ptr<std::exception> exceptionptr; // if non-NULL, then thread failed with exception
// wrapper around passing the functor
signallingevent startsignal;
const void * functorptr;
template<typename FUNCTION> static unsigned int __stdcall staticthreadproc (void * usv)
{
simplethread * us = (simplethread*) usv;
const FUNCTION body = *(const FUNCTION *) us->functorptr;
us->startsignal.flag();
us->threadproc (body);
return 0;
}
template<typename FUNCTION> void threadproc (const FUNCTION & body)
{
try
{
body(); // execute the function
}
catch (const std::exception & e)
{
fail (e);
}
catch (...) // we do not catch anything that is not based on std::exception
{
fprintf (stderr, "simplethread: thread proc failed with unexpected unknown exception, which is not allowed. Terminating\n");
fflush (stderr); // (needed?)
abort(); // should never happen
}
}
HANDLE threadhandle;
public:
template<typename FUNCTION> simplethread (const FUNCTION & body) : badallocexceptionptr (new std::bad_alloc()), functorptr (&body), startsignal (false)
{
unsigned int threadid;
uintptr_t rc = _beginthreadex (NULL/*security*/, 0/*stack*/, staticthreadproc<FUNCTION>, this, CREATE_SUSPENDED, &threadid);
if (rc == 0)
throw std::runtime_error ("simplethread: _beginthreadex() failed");
threadhandle = OpenThread (THREAD_ALL_ACCESS, FALSE, threadid);
if (threadhandle == NULL)
throw std::logic_error ("simplethread: _beginthreadex() unexpectedly did not return valid thread id"); // BUGBUG: leaking something
DWORD rc1 = ::ResumeThread (threadhandle);
if (rc1 == (DWORD) -1)
{
::TerminateThread (threadhandle, 0);
::CloseHandle (threadhandle);
throw std::logic_error ("simplethread: ResumeThread() failed unexpectedly");
}
try
{
startsignal.wait(); // wait until functor has been copied
}
catch (...)
{
::TerminateThread (threadhandle, 0);
::CloseHandle (threadhandle);
throw;
}
}
// check if the thread is still alive and without error
void check()
{
CAutoLock lock (*this);
// pass on a pending exception
if (exceptionptr)
throw *exceptionptr.get();
// the thread going away without error is also unexpected at this point
if (wait (0)) // (0 means don't block, so OK to call inside lock)
throw std::runtime_error ("check: thread terminated unexpectedly");
}
bool wait (DWORD dwMilliseconds = INFINITE)
{
DWORD rc = ::WaitForSingleObject (threadhandle, dwMilliseconds);
if (rc == WAIT_TIMEOUT)
return false;
else if (rc == WAIT_OBJECT_0)
return true;
else
throw std::runtime_error ("wait: WaitForSingleObject() failed unexpectedly");
}
// thread itself can set the failure condition, e.g. before it signals some other thread to pick it up
void fail (const std::exception & e)
{
// exception: remember it --this will remove the type info :(
CAutoLock lock (*this);
try // copy the exception--this may fail if we are out of memory
{
exceptionptr.reset (new std::runtime_error (e.what()));
}
catch (...) // failed to alloc: fall back to bad_alloc, which is most likely the cause in such situation
{
exceptionptr = badallocexceptionptr;
}
}
//void join()
//{
// check();
// wait();
// check_for_exception(); // (check() not sufficient because it would fail since thread is gone)
//}
~simplethread() throw()
{
// wait until it shuts down
try { wait(); }
catch (...) { ::TerminateThread (threadhandle, 0); }
// close the handle
::CloseHandle (threadhandle);
}
};
};};

Просмотреть файл

@ -1,123 +0,0 @@
//
// <copyright file="ssefloat4.h" company="Microsoft">
// Copyright (c) Microsoft Corporation. All rights reserved.
// </copyright>
//
// ssematrix.h -- matrix with SSE-accelerated operations
//
#pragma once
#ifdef _WIN32
#include <intrin.h> // for intrinsics
#endif
#ifdef __unix__
#include <x86intrin.h>
#endif
namespace msra { namespace math {
// ===========================================================================
// float4 -- wrapper around the rather ugly SSE intrinsics for float[4]
//
// Do not use the intrinsics outside anymore; instead add all you need into this class.
//
// MSDN links:
// basic: http://msdn.microsoft.com/en-us/library/x5c07e2a%28v=VS.80%29.aspx
// load/store: (add this)
// newer ones: (seems no single list available)
// ===========================================================================
class float4
{
__m128 v; // value
private:
// return the low 'float'
float f0() const { float f; _mm_store_ss (&f, v); return f; }
// construct from a __m128, assuming it is a f32 vector (needed for directly returning __m128 below)
float4 (const __m128 & v) : v (v) {}
// return as a __m128 --should this be a reference?
operator __m128() const { return v; }
// assign a __m128 (needed for using nested float4 objects inside this class, e.g. sum())
float4 & operator= (const __m128 & other) { v = other; return *this; }
public:
float4() {} // uninitialized
float4 (const float4 & f4) : v (f4.v) {}
float4 & operator= (const float4 & other) { v = other.v; return *this; }
// construct from a single float, copy to all components
float4 (float f) : v (_mm_load1_ps (&f)) {}
//float4 (float f) : v (_mm_set_ss (f)) {} // code seems more complex than _mm_load1_ps()
// basic math
float4 operator-() const { return _mm_sub_ps (_mm_setzero_ps(), v); } // UNTESTED; setzero is a composite
float4 operator& (const float4 & other) const { return _mm_and_ps (v, other); }
float4 operator| (const float4 & other) const { return _mm_or_ps (v, other); }
float4 operator+ (const float4 & other) const { return _mm_add_ps (v, other); }
float4 operator- (const float4 & other) const { return _mm_sub_ps (v, other); }
float4 operator* (const float4 & other) const { return _mm_mul_ps (v, other); }
float4 operator/ (const float4 & other) const { return _mm_div_ps (v, other); }
float4 & operator&= (const float4 & other) { v = _mm_and_ps (v, other); return *this; }
float4 & operator|= (const float4 & other) { v = _mm_or_ps (v, other); return *this; }
float4 & operator+= (const float4 & other) { v = _mm_add_ps (v, other); return *this; }
float4 & operator-= (const float4 & other) { v = _mm_sub_ps (v, other); return *this; }
float4 & operator*= (const float4 & other) { v = _mm_mul_ps (v, other); return *this; }
float4 & operator/= (const float4 & other) { v = _mm_div_ps (v, other); return *this; }
float4 operator>= (const float4 & other) const { return _mm_cmpge_ps (v, other); }
float4 operator<= (const float4 & other) const { return _mm_cmple_ps (v, other); }
// not yet implemented binary arithmetic ops: sqrt, rcp (reciprocal), rqsrt, min, max
// other goodies I came across (intrin.h):
// - _mm_prefetch
// - _mm_stream_ps --store without polluting cache
// - unknown: _mm_addsub_ps, _mm_hsub_ps, _mm_movehdup_ps, _mm_moveldup_ps, _mm_blend_ps, _mm_blendv_ps, _mm_insert_ps, _mm_extract_ps, _mm_round_ps
// - _mm_dp_ps dot product! http://msdn.microsoft.com/en-us/library/bb514054.aspx
// Not so interesting for long vectors, we get better numerical precision with parallel adds and hadd at the end
// prefetch a float4 from an address
static void prefetch (const float4 * p) { _mm_prefetch ((const char *) const_cast<float4 *> (p), _MM_HINT_T0); }
// transpose a 4x4 matrix
// Passing input as const ref to ensure aligned-ness
static void transpose (const float4 & col0, const float4 & col1, const float4 & col2, const float4 & col3,
float4 & row0, float4 & row1, float4 & row2, float4 & row3)
{ // note: the temp variable here gets completely eliminated by optimization
float4 m0 = col0; float4 m1 = col1; float4 m2 = col2; float4 m3 = col3;
_MM_TRANSPOSE4_PS (m0, m1, m2, m3); // 8 instructions for 16 elements
row0 = m0; row1 = m1; row2 = m2; row3 = m3;
}
// save a float4 to RAM bypassing the cache ('without polluting the cache')
void storewithoutcache (float4 & r4) const
{
//_mm_stream_ps ((float*) &r4, v);
r4 = v;
}
#if 0
// save a float4 to RAM bypassing the cache ('without polluting the cache')
void storewithoutcache (float4 * p4) const
{
//_mm_stream_ps ((float*) p4, v);
*p4 = v;
}
// save a float to RAM bypassing the cache ('without polluting the cache')
void storewithoutcache (float & r) const
{
_mm_stream_ss (&r, v);
}
#endif
// return the horizontal sum of all 4 components
// ... return float4, use another mechanism to store the low word
float sum() const { float4 hsum = _mm_hadd_ps (v, v); hsum = _mm_hadd_ps (hsum, hsum); return hsum.f0(); }
// please add anything else you might need HERE
};
};};

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -1,156 +1,156 @@
<?xml version="1.0" encoding="utf-8"?>
<Project DefaultTargets="Build" ToolsVersion="12.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<ItemGroup Label="ProjectConfigurations">
<ProjectConfiguration Include="Debug|x64">
<Configuration>Debug</Configuration>
<Platform>x64</Platform>
</ProjectConfiguration>
<ProjectConfiguration Include="Release|x64">
<Configuration>Release</Configuration>
<Platform>x64</Platform>
</ProjectConfiguration>
</ItemGroup>
<PropertyGroup Label="Globals">
<ProjectGuid>{9A2F2441-5972-4EA8-9215-4119FCE0FB68}</ProjectGuid>
<SccProjectName>
</SccProjectName>
<SccAuxPath>
</SccAuxPath>
<SccLocalPath>
</SccLocalPath>
<SccProvider>
</SccProvider>
<Keyword>Win32Proj</Keyword>
<RootNamespace>UCIReader</RootNamespace>
<ProjectName>LMSequenceReader</ProjectName>
</PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
<ConfigurationType>DynamicLibrary</ConfigurationType>
<UseDebugLibraries>true</UseDebugLibraries>
<PlatformToolset>v120</PlatformToolset>
<CharacterSet>Unicode</CharacterSet>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
<ConfigurationType>DynamicLibrary</ConfigurationType>
<UseDebugLibraries>false</UseDebugLibraries>
<PlatformToolset>v120</PlatformToolset>
<WholeProgramOptimization>true</WholeProgramOptimization>
<CharacterSet>Unicode</CharacterSet>
</PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
<ImportGroup Label="ExtensionSettings">
</ImportGroup>
<ImportGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="PropertySheets">
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
</ImportGroup>
<ImportGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="PropertySheets">
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
</ImportGroup>
<PropertyGroup Label="UserMacros" />
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<LinkIncremental>true</LinkIncremental>
<IncludePath>..\..\common\include;..\..\math\math;$(VCInstallDir)include;$(VCInstallDir)atlmfc\include;$(WindowsSDK_IncludePath);</IncludePath>
<LibraryPath>$(SolutionDir)$(Platform)\$(Configuration);$(VCInstallDir)lib\amd64;$(VCInstallDir)atlmfc\lib\amd64;$(WindowsSDK_LibraryPath_x64);</LibraryPath>
<IntDir>$(Platform)\$(Configuration)\$(ProjectName)\</IntDir>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<LinkIncremental>false</LinkIncremental>
<IncludePath>..\..\common\include;..\..\math\math;$(VCInstallDir)include;$(VCInstallDir)atlmfc\include;$(WindowsSDK_IncludePath);</IncludePath>
<LibraryPath>$(SolutionDir)$(Platform)\$(Configuration);$(VCInstallDir)lib\amd64;$(VCInstallDir)atlmfc\lib\amd64;$(WindowsSDK_LibraryPath_x64);</LibraryPath>
<IntDir>$(Platform)\$(Configuration)\$(ProjectName)\</IntDir>
</PropertyGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<ClCompile>
<PrecompiledHeader>Use</PrecompiledHeader>
<WarningLevel>Level4</WarningLevel>
<Optimization>Disabled</Optimization>
<PreprocessorDefinitions>WIN32;_DEBUG;_WINDOWS;_USRDLL;UCIREADER_EXPORTS;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<SDLCheck>true</SDLCheck>
<AdditionalIncludeDirectories>..\..\common\include;..\..\math\math</AdditionalIncludeDirectories>
<TreatWarningAsError>true</TreatWarningAsError>
</ClCompile>
<Link>
<SubSystem>Windows</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
<AdditionalDependencies>CNTKMath.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
<AdditionalLibraryDirectories>$(SolutionDir)$(Platform)\$(Configuration)\;..\..\math\$(Platform)\$(Configuration);..\$(Platform)\$(Configuration)</AdditionalLibraryDirectories>
</Link>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<ClCompile>
<WarningLevel>Level4</WarningLevel>
<PrecompiledHeader>Use</PrecompiledHeader>
<Optimization>MaxSpeed</Optimization>
<FunctionLevelLinking>true</FunctionLevelLinking>
<IntrinsicFunctions>true</IntrinsicFunctions>
<PreprocessorDefinitions>WIN32;NDEBUG;_WINDOWS;_USRDLL;UCIREADER_EXPORTS;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<SDLCheck>true</SDLCheck>
<AdditionalIncludeDirectories>..\..\common\include;..\..\math\math</AdditionalIncludeDirectories>
<OpenMPSupport>false</OpenMPSupport>
<AdditionalOptions>/d2Zi+ %(AdditionalOptions)</AdditionalOptions>
<TreatWarningAsError>true</TreatWarningAsError>
</ClCompile>
<Link>
<SubSystem>Windows</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
<EnableCOMDATFolding>true</EnableCOMDATFolding>
<OptimizeReferences>true</OptimizeReferences>
<AdditionalDependencies>CNTKmath.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
<AdditionalLibraryDirectories>..\..\math\$(Platform)\$(Configuration);$(SolutionDir)$(Platform)\$(Configuration)\</AdditionalLibraryDirectories>
<Profile>true</Profile>
</Link>
</ItemDefinitionGroup>
<ItemGroup>
<ClInclude Include="..\..\Common\Include\basetypes.h" />
<ClInclude Include="..\..\Common\Include\DataReader.h" />
<ClInclude Include="..\..\Common\Include\DataWriter.h" />
<ClInclude Include="..\..\Common\Include\File.h" />
<ClInclude Include="..\..\Common\Include\fileutil.h" />
<ClInclude Include="minibatchsourcehelpers.h" />
<ClInclude Include="SequenceWriter.h" />
<ClInclude Include="stdafx.h" />
<ClInclude Include="targetver.h" />
<ClInclude Include="SequenceReader.h" />
<ClInclude Include="SequenceParser.h" />
</ItemGroup>
<ItemGroup>
<ClCompile Include="..\..\Common\ConfigFile.cpp">
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">NotUsing</PrecompiledHeader>
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">NotUsing</PrecompiledHeader>
</ClCompile>
<ClCompile Include="..\..\Common\DataReader.cpp" />
<ClCompile Include="..\..\Common\DataWriter.cpp" />
<ClCompile Include="..\..\Common\File.cpp">
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">NotUsing</PrecompiledHeader>
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">NotUsing</PrecompiledHeader>
</ClCompile>
<ClCompile Include="..\..\Common\fileutil.cpp">
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">NotUsing</PrecompiledHeader>
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">NotUsing</PrecompiledHeader>
</ClCompile>
<ClCompile Include="Exports.cpp" />
<ClCompile Include="dllmain.cpp">
<CompileAsManaged Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</CompileAsManaged>
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
</PrecompiledHeader>
<CompileAsManaged Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</CompileAsManaged>
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
</PrecompiledHeader>
</ClCompile>
<ClCompile Include="SequenceWriter.cpp" />
<ClCompile Include="stdafx.cpp">
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">Create</PrecompiledHeader>
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">Create</PrecompiledHeader>
</ClCompile>
<ClCompile Include="SequenceReader.cpp" />
<ClCompile Include="SequenceParser.cpp" />
</ItemGroup>
<ItemGroup>
<Text Include="SentenceTest.txt" />
<Text Include="SequenceTest.txt" />
</ItemGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
<ImportGroup Label="ExtensionTargets">
</ImportGroup>
<?xml version="1.0" encoding="utf-8"?>
<Project DefaultTargets="Build" ToolsVersion="12.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<ItemGroup Label="ProjectConfigurations">
<ProjectConfiguration Include="Debug|x64">
<Configuration>Debug</Configuration>
<Platform>x64</Platform>
</ProjectConfiguration>
<ProjectConfiguration Include="Release|x64">
<Configuration>Release</Configuration>
<Platform>x64</Platform>
</ProjectConfiguration>
</ItemGroup>
<PropertyGroup Label="Globals">
<ProjectGuid>{9A2F2441-5972-4EA8-9215-4119FCE0FB68}</ProjectGuid>
<SccProjectName>
</SccProjectName>
<SccAuxPath>
</SccAuxPath>
<SccLocalPath>
</SccLocalPath>
<SccProvider>
</SccProvider>
<Keyword>Win32Proj</Keyword>
<RootNamespace>UCIReader</RootNamespace>
<ProjectName>LMSequenceReader</ProjectName>
</PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
<ConfigurationType>DynamicLibrary</ConfigurationType>
<UseDebugLibraries>true</UseDebugLibraries>
<PlatformToolset>v120</PlatformToolset>
<CharacterSet>Unicode</CharacterSet>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
<ConfigurationType>DynamicLibrary</ConfigurationType>
<UseDebugLibraries>false</UseDebugLibraries>
<PlatformToolset>v120</PlatformToolset>
<WholeProgramOptimization>true</WholeProgramOptimization>
<CharacterSet>Unicode</CharacterSet>
</PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
<ImportGroup Label="ExtensionSettings">
</ImportGroup>
<ImportGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="PropertySheets">
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
</ImportGroup>
<ImportGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="PropertySheets">
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
</ImportGroup>
<PropertyGroup Label="UserMacros" />
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<LinkIncremental>true</LinkIncremental>
<IncludePath>..\..\common\include;..\..\math\math;$(VCInstallDir)include;$(VCInstallDir)atlmfc\include;$(WindowsSDK_IncludePath);</IncludePath>
<LibraryPath>$(SolutionDir)$(Platform)\$(Configuration);$(VCInstallDir)lib\amd64;$(VCInstallDir)atlmfc\lib\amd64;$(WindowsSDK_LibraryPath_x64);</LibraryPath>
<IntDir>$(Platform)\$(Configuration)\$(ProjectName)\</IntDir>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<LinkIncremental>false</LinkIncremental>
<IncludePath>..\..\common\include;..\..\math\math;$(VCInstallDir)include;$(VCInstallDir)atlmfc\include;$(WindowsSDK_IncludePath);</IncludePath>
<LibraryPath>$(SolutionDir)$(Platform)\$(Configuration);$(VCInstallDir)lib\amd64;$(VCInstallDir)atlmfc\lib\amd64;$(WindowsSDK_LibraryPath_x64);</LibraryPath>
<IntDir>$(Platform)\$(Configuration)\$(ProjectName)\</IntDir>
</PropertyGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<ClCompile>
<PrecompiledHeader>Use</PrecompiledHeader>
<WarningLevel>Level4</WarningLevel>
<Optimization>Disabled</Optimization>
<PreprocessorDefinitions>WIN32;_DEBUG;_WINDOWS;_USRDLL;UCIREADER_EXPORTS;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<SDLCheck>true</SDLCheck>
<AdditionalIncludeDirectories>..\..\common\include;..\..\math\math</AdditionalIncludeDirectories>
<TreatWarningAsError>true</TreatWarningAsError>
</ClCompile>
<Link>
<SubSystem>Windows</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
<AdditionalDependencies>CNTKMathDll.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
<AdditionalLibraryDirectories>$(SolutionDir)$(Platform)\$(Configuration)\;..\..\math\$(Platform)\$(Configuration);..\$(Platform)\$(Configuration)</AdditionalLibraryDirectories>
</Link>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<ClCompile>
<WarningLevel>Level4</WarningLevel>
<PrecompiledHeader>Use</PrecompiledHeader>
<Optimization>MaxSpeed</Optimization>
<FunctionLevelLinking>true</FunctionLevelLinking>
<IntrinsicFunctions>true</IntrinsicFunctions>
<PreprocessorDefinitions>WIN32;NDEBUG;_WINDOWS;_USRDLL;UCIREADER_EXPORTS;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<SDLCheck>true</SDLCheck>
<AdditionalIncludeDirectories>..\..\common\include;..\..\math\math</AdditionalIncludeDirectories>
<OpenMPSupport>false</OpenMPSupport>
<AdditionalOptions>/d2Zi+ %(AdditionalOptions)</AdditionalOptions>
<TreatWarningAsError>true</TreatWarningAsError>
</ClCompile>
<Link>
<SubSystem>Windows</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
<EnableCOMDATFolding>true</EnableCOMDATFolding>
<OptimizeReferences>true</OptimizeReferences>
<AdditionalDependencies>CNTKMathDll.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
<AdditionalLibraryDirectories>..\..\math\$(Platform)\$(Configuration);$(SolutionDir)$(Platform)\$(Configuration)\</AdditionalLibraryDirectories>
<Profile>true</Profile>
</Link>
</ItemDefinitionGroup>
<ItemGroup>
<ClInclude Include="..\..\Common\Include\basetypes.h" />
<ClInclude Include="..\..\Common\Include\DataReader.h" />
<ClInclude Include="..\..\Common\Include\DataWriter.h" />
<ClInclude Include="..\..\Common\Include\File.h" />
<ClInclude Include="..\..\Common\Include\fileutil.h" />
<ClInclude Include="minibatchsourcehelpers.h" />
<ClInclude Include="SequenceWriter.h" />
<ClInclude Include="stdafx.h" />
<ClInclude Include="targetver.h" />
<ClInclude Include="SequenceReader.h" />
<ClInclude Include="SequenceParser.h" />
</ItemGroup>
<ItemGroup>
<ClCompile Include="..\..\Common\ConfigFile.cpp">
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">NotUsing</PrecompiledHeader>
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">NotUsing</PrecompiledHeader>
</ClCompile>
<ClCompile Include="..\..\Common\DataReader.cpp" />
<ClCompile Include="..\..\Common\DataWriter.cpp" />
<ClCompile Include="..\..\Common\File.cpp">
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">NotUsing</PrecompiledHeader>
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">NotUsing</PrecompiledHeader>
</ClCompile>
<ClCompile Include="..\..\Common\fileutil.cpp">
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">NotUsing</PrecompiledHeader>
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">NotUsing</PrecompiledHeader>
</ClCompile>
<ClCompile Include="Exports.cpp" />
<ClCompile Include="dllmain.cpp">
<CompileAsManaged Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</CompileAsManaged>
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
</PrecompiledHeader>
<CompileAsManaged Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</CompileAsManaged>
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
</PrecompiledHeader>
</ClCompile>
<ClCompile Include="SequenceWriter.cpp" />
<ClCompile Include="stdafx.cpp">
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">Create</PrecompiledHeader>
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">Create</PrecompiledHeader>
</ClCompile>
<ClCompile Include="SequenceReader.cpp" />
<ClCompile Include="SequenceParser.cpp" />
</ItemGroup>
<ItemGroup>
<Text Include="SentenceTest.txt" />
<Text Include="SequenceTest.txt" />
</ItemGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
<ImportGroup Label="ExtensionTargets">
</ImportGroup>
</Project>

Просмотреть файл

@ -2101,7 +2101,7 @@ void BatchSequenceReader<ElemType>::GetLabelOutput(std::map < std::wstring,
}
template<class ElemType>
void BatchSequenceReader<ElemType>::SetSentenceSegBatch(Matrix<ElemType>& sentenceBegin, vector<MinibatchPackingFlag>& minibatchPackingFlag)
void BatchSequenceReader<ElemType>::SetSentenceSegBatch(Matrix<float>& sentenceBegin, vector<MinibatchPackingFlag>& minibatchPackingFlag)
{
DEVICEID_TYPE device = mtSentenceBegin.GetDeviceId();
mtSentenceBegin.TransferFromDeviceToDevice(device, sentenceBegin.GetDeviceId(), true);

Просмотреть файл

@ -76,7 +76,7 @@ public:
double logprob(int i) const { if (uniform_sampling) return uniform_log_prob; else return m_log_prob[i]; }
template <typename Engine>
int sample(Engine &eng) const
int sample(Engine &eng)
{
int m = unif_int(eng);
if (uniform_sampling)
@ -353,7 +353,7 @@ private:
bool mSentenceEnd;
bool mSentenceBegin;
Matrix<ElemType> mtSentenceBegin;
Matrix<float> mtSentenceBegin;
vector<MinibatchPackingFlag> m_minibatchPackingFlag;
public:
@ -396,7 +396,7 @@ public:
size_t NumberSlicesInEachRecurrentIter();
void SetSentenceSegBatch(std::vector<size_t> &sentenceEnd);
void SetSentenceSegBatch(Matrix<ElemType>& sentenceBegin, vector<MinibatchPackingFlag>& minibatchPackingFlag);
void SetSentenceSegBatch(Matrix<float>& sentenceBegin, vector<MinibatchPackingFlag>& minibatchPackingFlag);
int GetSentenceEndIdFromOutputLabel();

Просмотреть файл

@ -984,7 +984,7 @@ size_t BatchLUSequenceReader<ElemType>::GetLabelOutput(std::map<std::wstring,
}
template<class ElemType>
void BatchLUSequenceReader<ElemType>::SetSentenceSegBatch(Matrix<ElemType>& sentenceBegin, vector<MinibatchPackingFlag>& minibatchPackingFlag)
void BatchLUSequenceReader<ElemType>::SetSentenceSegBatch(Matrix<float>& sentenceBegin, vector<MinibatchPackingFlag>& minibatchPackingFlag)
{
DEVICEID_TYPE device = mtSentenceBegin.GetDeviceId();
mtSentenceBegin.TransferFromDeviceToDevice(device, sentenceBegin.GetDeviceId(), true);
@ -1291,7 +1291,7 @@ void MultiIOBatchLUSequenceReader<ElemType>::StartMinibatchLoop(size_t mbSize, s
}
template<class ElemType>
void MultiIOBatchLUSequenceReader<ElemType>::SetSentenceSegBatch(Matrix<ElemType> & sentenceBegin, vector<MinibatchPackingFlag>& minibatchPackingFlag)
void MultiIOBatchLUSequenceReader<ElemType>::SetSentenceSegBatch(Matrix<float> & sentenceBegin, vector<MinibatchPackingFlag>& minibatchPackingFlag)
{
/// run for each reader
vector<size_t> col;

Просмотреть файл

@ -301,7 +301,7 @@ public:
size_t NumberSlicesInEachRecurrentIter();
void SetNbrSlicesEachRecurrentIter(const size_t mz);
void SetSentenceSegBatch(Matrix<ElemType> & sentenceBegin, vector<MinibatchPackingFlag>& minibatchPackingFlag);
void SetSentenceSegBatch(Matrix<float> & sentenceBegin, vector<MinibatchPackingFlag>& minibatchPackingFlag);
public:
void GetClassInfo(LabelInfo& lblInfo);
@ -359,7 +359,7 @@ public:
/// the second data stream has two sentences, with 0 indicating begining of sentences
/// you may use 1 even if a sentence begins at that position, in this case, the trainer will carry over hidden states to the following
/// frame.
Matrix<ElemType> mtSentenceBegin;
Matrix<float> mtSentenceBegin;
/// a matrix of 1 x n_length
/// 1 denotes the case that there exists sentnece begin or no_labels case in this frame
@ -399,7 +399,7 @@ public:
void StartMinibatchLoop(size_t mbSize, size_t epoch, size_t requestedEpochSamples);
void SetSentenceSegBatch(Matrix<ElemType> & sentenceBegin, vector<MinibatchPackingFlag>& minibatchPackingFlag);
void SetSentenceSegBatch(Matrix<float> & sentenceBegin, vector<MinibatchPackingFlag>& minibatchPackingFlag);
size_t NumberSlicesInEachRecurrentIter();

Просмотреть файл

@ -73,7 +73,7 @@
<Link>
<SubSystem>Windows</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
<AdditionalDependencies>CNTKMath.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
<AdditionalDependencies>CNTKMathDll.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
<AdditionalLibraryDirectories>$(SolutionDir)$(Platform)\$(Configuration)\;..\..\math\$(Platform)\$(Configuration);..\$(Platform)\$(Configuration)</AdditionalLibraryDirectories>
</Link>
</ItemDefinitionGroup>
@ -96,7 +96,7 @@
<GenerateDebugInformation>true</GenerateDebugInformation>
<EnableCOMDATFolding>true</EnableCOMDATFolding>
<OptimizeReferences>true</OptimizeReferences>
<AdditionalDependencies>CNTKmath.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
<AdditionalDependencies>CNTKMathDll.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
<AdditionalLibraryDirectories>..\..\math\$(Platform)\$(Configuration);$(SolutionDir)$(Platform)\$(Configuration)\</AdditionalLibraryDirectories>
<Profile>true</Profile>
</Link>

Просмотреть файл

@ -145,7 +145,7 @@ public:
size_t NumberSlicesInEachRecurrentIter() { return 1 ;}
void SetNbrSlicesEachRecurrentIter(const size_t) { };
void SetSentenceSegBatch(Matrix<ElemType> &, vector<MinibatchPackingFlag>& ){};
void SetSentenceSegBatch(Matrix<float> &, vector<MinibatchPackingFlag>&){};
virtual const std::map<LabelIdType, LabelType>& GetLabelMapping(const std::wstring& sectionName);
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, typename LabelType>& labelMapping);
virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart=0);

Просмотреть файл

@ -74,7 +74,7 @@
<Link>
<SubSystem>Windows</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
<AdditionalDependencies>CNTKMath.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
<AdditionalDependencies>CNTKMathDll.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
<AdditionalLibraryDirectories>$(SolutionDir)$(Platform)\$(Configuration)\;..\..\math\$(Platform)\$(Configuration);..\$(Platform)\$(Configuration)</AdditionalLibraryDirectories>
</Link>
</ItemDefinitionGroup>
@ -97,7 +97,7 @@
<GenerateDebugInformation>true</GenerateDebugInformation>
<EnableCOMDATFolding>true</EnableCOMDATFolding>
<OptimizeReferences>true</OptimizeReferences>
<AdditionalDependencies>CNTKmath.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
<AdditionalDependencies>CNTKMathDll.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
<AdditionalLibraryDirectories>..\..\math\$(Platform)\$(Configuration);$(SolutionDir)$(Platform)\$(Configuration)\</AdditionalLibraryDirectories>
<Profile>true</Profile>
</Link>

Просмотреть файл

@ -58,7 +58,7 @@ public:
size_t NumberSlicesInEachRecurrentIter() { return 1 ;}
void SetNbrSlicesEachRecurrentIter(const size_t) { };
void SetSentenceSegBatch(Matrix<ElemType> &/*sentenceBegin*/, vector<MinibatchPackingFlag>& /*sentenceExistsBeginOrNoLabels*/) {};
void SetSentenceSegBatch(Matrix<float> &/*sentenceBegin*/, vector<MinibatchPackingFlag>& /*sentenceExistsBeginOrNoLabels*/) {};
virtual const std::map<LabelIdType, LabelType>& GetLabelMapping(const std::wstring& sectionName);
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, typename LabelType>& labelMapping);
virtual bool GetData(const std::wstring& /*sectionName*/, size_t /*numRecords*/, void* /*data*/, size_t& /*dataBufferSize*/, size_t /*recordStart*/) { throw runtime_error("GetData not supported in SparsePCReader"); };

Просмотреть файл

@ -1,151 +1,151 @@
<?xml version="1.0" encoding="utf-8"?>
<Project DefaultTargets="Build" ToolsVersion="12.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<ItemGroup Label="ProjectConfigurations">
<ProjectConfiguration Include="Debug|x64">
<Configuration>Debug</Configuration>
<Platform>x64</Platform>
</ProjectConfiguration>
<ProjectConfiguration Include="Release|x64">
<Configuration>Release</Configuration>
<Platform>x64</Platform>
</ProjectConfiguration>
</ItemGroup>
<PropertyGroup Label="Globals">
<ProjectGuid>{CE429AA2-3778-4619-8FD1-49BA3B81197B}</ProjectGuid>
<SccProjectName>
</SccProjectName>
<SccAuxPath>
</SccAuxPath>
<SccLocalPath>
</SccLocalPath>
<SccProvider>
</SccProvider>
<Keyword>Win32Proj</Keyword>
<RootNamespace>SparsePCReader</RootNamespace>
</PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
<ConfigurationType>DynamicLibrary</ConfigurationType>
<UseDebugLibraries>true</UseDebugLibraries>
<PlatformToolset>v120</PlatformToolset>
<CharacterSet>Unicode</CharacterSet>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
<ConfigurationType>DynamicLibrary</ConfigurationType>
<UseDebugLibraries>false</UseDebugLibraries>
<PlatformToolset>v120</PlatformToolset>
<WholeProgramOptimization>true</WholeProgramOptimization>
<CharacterSet>Unicode</CharacterSet>
</PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
<ImportGroup Label="ExtensionSettings">
</ImportGroup>
<ImportGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="PropertySheets">
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
</ImportGroup>
<ImportGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="PropertySheets">
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
</ImportGroup>
<PropertyGroup Label="UserMacros" />
<PropertyGroup Label="UserMacros" />
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<LinkIncremental>true</LinkIncremental>
<IncludePath>..\..\common\include;..\..\math\math;$(VCInstallDir)include;$(VCInstallDir)atlmfc\include;$(WindowsSDK_IncludePath);</IncludePath>
<LibraryPath>$(SolutionDir)$(Platform)\$(Configuration);$(VCInstallDir)lib\amd64;$(VCInstallDir)atlmfc\lib\amd64;$(WindowsSDK_LibraryPath_x64);</LibraryPath>
<IntDir>$(Platform)\$(Configuration)\$(ProjectName)\</IntDir>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<LinkIncremental>false</LinkIncremental>
<IncludePath>c:\Program Files\Microsoft MPI\Inc;..\..\common\include;..\..\math\math;$(VCInstallDir)include;$(VCInstallDir)atlmfc\include;$(WindowsSDK_IncludePath);</IncludePath>
<LibraryPath>c:\Program Files\Microsoft MPI\Lib\amd64;$(SolutionDir)$(Platform)\$(Configuration);$(VCInstallDir)lib\amd64;$(VCInstallDir)atlmfc\lib\amd64;$(WindowsSDK_LibraryPath_x64);</LibraryPath>
<IntDir>$(Platform)\$(Configuration)\$(ProjectName)\</IntDir>
</PropertyGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<ClCompile>
<PrecompiledHeader>NotUsing</PrecompiledHeader>
<WarningLevel>Level4</WarningLevel>
<Optimization>Disabled</Optimization>
<PreprocessorDefinitions>_CRT_SECURE_NO_WARNINGS;WIN32;_DEBUG;_WINDOWS;_USRDLL;SparsePCREADER_EXPORTS;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<SDLCheck>true</SDLCheck>
<AdditionalIncludeDirectories>..\..\common\include;..\..\math\math</AdditionalIncludeDirectories>
<TreatWarningAsError>true</TreatWarningAsError>
<AdditionalOptions>/bigobj %(AdditionalOptions)</AdditionalOptions>
</ClCompile>
<Link>
<SubSystem>Windows</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
<AdditionalDependencies>CNTKMath.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
<AdditionalLibraryDirectories>$(SolutionDir)$(Platform)\$(Configuration)\;..\..\math\$(Platform)\$(Configuration);..\$(Platform)\$(Configuration)</AdditionalLibraryDirectories>
</Link>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<ClCompile>
<WarningLevel>Level4</WarningLevel>
<PrecompiledHeader>Use</PrecompiledHeader>
<Optimization>MaxSpeed</Optimization>
<FunctionLevelLinking>true</FunctionLevelLinking>
<IntrinsicFunctions>true</IntrinsicFunctions>
<PreprocessorDefinitions>_CRT_SECURE_NO_WARNINGS;WIN32;NDEBUG;_WINDOWS;_USRDLL;SparsePCREADER_EXPORTS;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<SDLCheck>true</SDLCheck>
<AdditionalIncludeDirectories>..\..\common\include;..\..\math\math</AdditionalIncludeDirectories>
<OpenMPSupport>false</OpenMPSupport>
<AdditionalOptions>/d2Zi+ %(AdditionalOptions)</AdditionalOptions>
<TreatWarningAsError>true</TreatWarningAsError>
</ClCompile>
<Link>
<SubSystem>Windows</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
<EnableCOMDATFolding>true</EnableCOMDATFolding>
<OptimizeReferences>true</OptimizeReferences>
<AdditionalDependencies>CNTKmath.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
<AdditionalLibraryDirectories>..\..\math\$(Platform)\$(Configuration);$(SolutionDir)$(Platform)\$(Configuration)\</AdditionalLibraryDirectories>
<Profile>true</Profile>
</Link>
</ItemDefinitionGroup>
<ItemGroup>
<ClInclude Include="..\..\Common\Include\basetypes.h" />
<ClInclude Include="..\..\Common\Include\DataReader.h" />
<ClInclude Include="..\..\Common\Include\DataWriter.h">
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</ExcludedFromBuild>
</ClInclude>
<ClInclude Include="..\..\Common\Include\File.h">
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</ExcludedFromBuild>
</ClInclude>
<ClInclude Include="..\..\Common\Include\fileutil.h">
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</ExcludedFromBuild>
</ClInclude>
<ClInclude Include="SparsePCReader.h" />
<ClInclude Include="minibatchsourcehelpers.h" />
<ClInclude Include="stdafx.h" />
<ClInclude Include="targetver.h" />
</ItemGroup>
<ItemGroup>
<ClCompile Include="..\..\Common\ConfigFile.cpp">
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">NotUsing</PrecompiledHeader>
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">NotUsing</PrecompiledHeader>
</ClCompile>
<ClCompile Include="..\..\Common\DataReader.cpp" />
<ClCompile Include="..\..\Common\DataWriter.cpp">
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">NotUsing</PrecompiledHeader>
</ClCompile>
<ClCompile Include="..\..\Common\File.cpp">
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">NotUsing</PrecompiledHeader>
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">NotUsing</PrecompiledHeader>
</ClCompile>
<ClCompile Include="..\..\Common\fileutil.cpp">
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">NotUsing</PrecompiledHeader>
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">NotUsing</PrecompiledHeader>
</ClCompile>
<ClCompile Include="dllmain.cpp" />
<ClCompile Include="SparsePCReader.cpp">
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">Use</PrecompiledHeader>
</ClCompile>
<ClCompile Include="Exports.cpp" />
<ClCompile Include="stdafx.cpp">
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">Create</PrecompiledHeader>
</ClCompile>
</ItemGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
<ImportGroup Label="ExtensionTargets">
</ImportGroup>
<?xml version="1.0" encoding="utf-8"?>
<Project DefaultTargets="Build" ToolsVersion="12.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<ItemGroup Label="ProjectConfigurations">
<ProjectConfiguration Include="Debug|x64">
<Configuration>Debug</Configuration>
<Platform>x64</Platform>
</ProjectConfiguration>
<ProjectConfiguration Include="Release|x64">
<Configuration>Release</Configuration>
<Platform>x64</Platform>
</ProjectConfiguration>
</ItemGroup>
<PropertyGroup Label="Globals">
<ProjectGuid>{CE429AA2-3778-4619-8FD1-49BA3B81197B}</ProjectGuid>
<SccProjectName>
</SccProjectName>
<SccAuxPath>
</SccAuxPath>
<SccLocalPath>
</SccLocalPath>
<SccProvider>
</SccProvider>
<Keyword>Win32Proj</Keyword>
<RootNamespace>SparsePCReader</RootNamespace>
</PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
<ConfigurationType>DynamicLibrary</ConfigurationType>
<UseDebugLibraries>true</UseDebugLibraries>
<PlatformToolset>v120</PlatformToolset>
<CharacterSet>Unicode</CharacterSet>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
<ConfigurationType>DynamicLibrary</ConfigurationType>
<UseDebugLibraries>false</UseDebugLibraries>
<PlatformToolset>v120</PlatformToolset>
<WholeProgramOptimization>true</WholeProgramOptimization>
<CharacterSet>Unicode</CharacterSet>
</PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
<ImportGroup Label="ExtensionSettings">
</ImportGroup>
<ImportGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="PropertySheets">
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
</ImportGroup>
<ImportGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="PropertySheets">
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
</ImportGroup>
<PropertyGroup Label="UserMacros" />
<PropertyGroup Label="UserMacros" />
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<LinkIncremental>true</LinkIncremental>
<IncludePath>..\..\common\include;..\..\math\math;$(VCInstallDir)include;$(VCInstallDir)atlmfc\include;$(WindowsSDK_IncludePath);</IncludePath>
<LibraryPath>$(SolutionDir)$(Platform)\$(Configuration);$(VCInstallDir)lib\amd64;$(VCInstallDir)atlmfc\lib\amd64;$(WindowsSDK_LibraryPath_x64);</LibraryPath>
<IntDir>$(Platform)\$(Configuration)\$(ProjectName)\</IntDir>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<LinkIncremental>false</LinkIncremental>
<IncludePath>c:\Program Files\Microsoft MPI\Inc;..\..\common\include;..\..\math\math;$(VCInstallDir)include;$(VCInstallDir)atlmfc\include;$(WindowsSDK_IncludePath);</IncludePath>
<LibraryPath>c:\Program Files\Microsoft MPI\Lib\amd64;$(SolutionDir)$(Platform)\$(Configuration);$(VCInstallDir)lib\amd64;$(VCInstallDir)atlmfc\lib\amd64;$(WindowsSDK_LibraryPath_x64);</LibraryPath>
<IntDir>$(Platform)\$(Configuration)\$(ProjectName)\</IntDir>
</PropertyGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<ClCompile>
<PrecompiledHeader>NotUsing</PrecompiledHeader>
<WarningLevel>Level4</WarningLevel>
<Optimization>Disabled</Optimization>
<PreprocessorDefinitions>_CRT_SECURE_NO_WARNINGS;WIN32;_DEBUG;_WINDOWS;_USRDLL;SparsePCREADER_EXPORTS;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<SDLCheck>true</SDLCheck>
<AdditionalIncludeDirectories>..\..\common\include;..\..\math\math</AdditionalIncludeDirectories>
<TreatWarningAsError>true</TreatWarningAsError>
<AdditionalOptions>/bigobj %(AdditionalOptions)</AdditionalOptions>
</ClCompile>
<Link>
<SubSystem>Windows</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
<AdditionalDependencies>CNTKMathDll.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
<AdditionalLibraryDirectories>$(SolutionDir)$(Platform)\$(Configuration)\;..\..\math\$(Platform)\$(Configuration);..\$(Platform)\$(Configuration)</AdditionalLibraryDirectories>
</Link>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<ClCompile>
<WarningLevel>Level4</WarningLevel>
<PrecompiledHeader>Use</PrecompiledHeader>
<Optimization>MaxSpeed</Optimization>
<FunctionLevelLinking>true</FunctionLevelLinking>
<IntrinsicFunctions>true</IntrinsicFunctions>
<PreprocessorDefinitions>_CRT_SECURE_NO_WARNINGS;WIN32;NDEBUG;_WINDOWS;_USRDLL;SparsePCREADER_EXPORTS;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<SDLCheck>true</SDLCheck>
<AdditionalIncludeDirectories>..\..\common\include;..\..\math\math</AdditionalIncludeDirectories>
<OpenMPSupport>false</OpenMPSupport>
<AdditionalOptions>/d2Zi+ %(AdditionalOptions)</AdditionalOptions>
<TreatWarningAsError>true</TreatWarningAsError>
</ClCompile>
<Link>
<SubSystem>Windows</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
<EnableCOMDATFolding>true</EnableCOMDATFolding>
<OptimizeReferences>true</OptimizeReferences>
<AdditionalDependencies>CNTKMathDll.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
<AdditionalLibraryDirectories>..\..\math\$(Platform)\$(Configuration);$(SolutionDir)$(Platform)\$(Configuration)\</AdditionalLibraryDirectories>
<Profile>true</Profile>
</Link>
</ItemDefinitionGroup>
<ItemGroup>
<ClInclude Include="..\..\Common\Include\basetypes.h" />
<ClInclude Include="..\..\Common\Include\DataReader.h" />
<ClInclude Include="..\..\Common\Include\DataWriter.h">
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</ExcludedFromBuild>
</ClInclude>
<ClInclude Include="..\..\Common\Include\File.h">
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</ExcludedFromBuild>
</ClInclude>
<ClInclude Include="..\..\Common\Include\fileutil.h">
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</ExcludedFromBuild>
</ClInclude>
<ClInclude Include="SparsePCReader.h" />
<ClInclude Include="minibatchsourcehelpers.h" />
<ClInclude Include="stdafx.h" />
<ClInclude Include="targetver.h" />
</ItemGroup>
<ItemGroup>
<ClCompile Include="..\..\Common\ConfigFile.cpp">
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">NotUsing</PrecompiledHeader>
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">NotUsing</PrecompiledHeader>
</ClCompile>
<ClCompile Include="..\..\Common\DataReader.cpp" />
<ClCompile Include="..\..\Common\DataWriter.cpp">
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">NotUsing</PrecompiledHeader>
</ClCompile>
<ClCompile Include="..\..\Common\File.cpp">
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">NotUsing</PrecompiledHeader>
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">NotUsing</PrecompiledHeader>
</ClCompile>
<ClCompile Include="..\..\Common\fileutil.cpp">
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">NotUsing</PrecompiledHeader>
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">NotUsing</PrecompiledHeader>
</ClCompile>
<ClCompile Include="dllmain.cpp" />
<ClCompile Include="SparsePCReader.cpp">
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">Use</PrecompiledHeader>
</ClCompile>
<ClCompile Include="Exports.cpp" />
<ClCompile Include="stdafx.cpp">
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">Create</PrecompiledHeader>
</ClCompile>
</ItemGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
<ImportGroup Label="ExtensionTargets">
</ImportGroup>
</Project>

Просмотреть файл

@ -112,13 +112,13 @@ public:
virtual bool GetMinibatch(std::map<std::wstring, Matrix<ElemType>*>& matrices);
size_t NumberSlicesInEachRecurrentIter() { return mBlgSize; }
void SetSentenceSegBatch(Matrix<ElemType> &, vector<MinibatchPackingFlag>& ){};
void SetSentenceSegBatch(Matrix<float> &, vector<MinibatchPackingFlag>&){};
virtual const std::map<LabelIdType, LabelType>& GetLabelMapping(const std::wstring& sectionName);
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, LabelType>& labelMapping);
virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart=0);
virtual bool DataEnd(EndDataType endDataType);
void SetSentenceSegBatch(Matrix<ElemType>&, Matrix<ElemType>&) { };
void SetSentenceSegBatch(Matrix<float>&, Matrix<ElemType>&) { };
void SetNbrSlicesEachRecurrentIter(const size_t sz);

Просмотреть файл

@ -72,7 +72,7 @@
<Link>
<SubSystem>Windows</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
<AdditionalDependencies>CNTKMath.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
<AdditionalDependencies>CNTKMathDll.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
<AdditionalLibraryDirectories>$(SolutionDir)$(Platform)\$(Configuration)\;..\..\math\$(Platform)\$(Configuration);..\$(Platform)\$(Configuration)</AdditionalLibraryDirectories>
</Link>
</ItemDefinitionGroup>
@ -95,7 +95,7 @@
<GenerateDebugInformation>true</GenerateDebugInformation>
<EnableCOMDATFolding>true</EnableCOMDATFolding>
<OptimizeReferences>true</OptimizeReferences>
<AdditionalDependencies>CNTKmath.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
<AdditionalDependencies>CNTKMathDll.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
<AdditionalLibraryDirectories>..\..\math\$(Platform)\$(Configuration);$(SolutionDir)$(Platform)\$(Configuration)\</AdditionalLibraryDirectories>
<Profile>true</Profile>
</Link>

Некоторые файлы не были показаны из-за слишком большого количества измененных файлов Показать больше