Redesign the binary format, refactor and update the reader
  implementation, as well as the ctf2bin.py converter.
This commit is contained in:
Alexey Reznichenko 2017-03-01 13:33:08 +01:00
Parent b55eaa126a
Commit ff49f1e3c6
29 changed files with 862 additions and 16459 deletions

View file

@@ -5,7 +5,7 @@
#
# The header file must list all of the streams in the input file in the
# following format:
# <desired stream name> TAB <stream alias> TAB <matrix type> TAB <sample dimension>
# <desired stream name> <stream alias> <matrix type> <sample dimension>
#
# Where:
# <desired stream name> is the desired name for the input in CNTK.
@@ -16,275 +16,285 @@
import sys
import argparse
import re
import struct
import tempfile
import shutil
import os
from collections import OrderedDict
# This will convert data in the ctf format into binary format
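# Note: 0x636e746b5f62696e is the ASCII encoding of 'cntk_bin'.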
MAGIC_NUMBER = 0x636e746b5f62696e;
CBF_VERSION = 1;
class ElementType:
FLOAT = 0
DOUBLE = 1
class MatrixEncodingType:
DENSE = 0
SPARSE_CSC = 1
#COMPRESSED_DENSE = 2
#COMPRESSED_SPARSE_CSC = 3
# This will convert data in the ctf format into the binary format
class Converter(object):
def __init__(self, name, sampleDim):
def __init__(self, name, sample_dim, element_type):
self.name = name
self.sampleDim = sampleDim
self.vals = list()
self.sample_dim = sample_dim
# contains length (in samples) for each sequence in the chunk
self.sequences = []
self.element_type = element_type
def getName(self):
return self.name
def write_header(self, output):
# First is the matrix type.
output.write(struct.pack('B', self.get_matrix_type()))
# Next comes the stream name.
output.write(struct.pack('I', len(self.name)))
output.write(self.name.encode('ascii'))
# Next is the elem type
output.write(struct.pack('B', self.element_type))
# Finally, the sample dimension.
output.write(struct.pack('I', self.sample_dim))
def getSampleDim(self):
return self.sampleDim
def write_signed_ints(self, output, ints):
output.write(b''.join([struct.pack('i', x) for x in ints]))
def clear(self):
self.vals = list()
def write_floats(self, output, floats):
format = 'f' if self.is_float() else 'd'
output.write(b''.join([struct.pack(format, x) for x in floats]))
def addSequence(self):
self.vals.append(list())
def is_float(self):
return self.element_type == ElementType.FLOAT
def appendSample(self, sample):
if( len(sample) != self.sampleDim ):
raise ValueError(
"Invalid sample dimension for input {0}".format(self.name))
def get_matrix_type(self):
raise NotImplementedError()
if( len(self.vals) == 0 ):
self.vals.append( list() )
def reset(self):
self.sequences = []
self.vals[-1].append( sample )
def toString(self):
output = ""
for seq in self.vals:
for samp in seq:
output += "\t" + " ".join(samp )
output += "\n"
return output
def start_sequence(self):
self.sequences.append([])
def add_sample(self, sample):
raise NotImplementedError()
# Specialization for dense inputs
class DenseConverter(Converter):
def __init__(self, name, sampleDim):
Converter.__init__(self, name, sampleDim)
def headerBytes(self):
output = bytearray()
# First is the matrix type. Dense is type 0
output += struct.pack( "i", 0 )
# Next is the elem type, currently float only
output += struct.pack( "i", 0 )
# Finally is whether or not this is a sequence
output += struct.pack( "i", self.sampleDim )
return output
def get_matrix_type(self):
return MatrixEncodingType.DENSE;
def toBytes(self):
output = bytearray()
for sequence in self.vals:
if( len(sequence) != 1 ):
raise ValueError("Dense sequences currently not supported.")
def add_sample(self, sample):
if(len(sample) != self.sample_dim):
raise ValueError(
"Invalid sample dimension for input {0}".format(self.name))
for sample in sequence[0]:
output += struct.pack( "f", float(sample) )
byte_size = len(sample) * (4 if self.is_float() else 8)
return output
if(len(self.sequences) == 0):
self.sequences.append([])
byte_size += 4;
self.sequences[-1].append([float(x) for x in sample])
return byte_size
def write_data(self, output):
for sequence in self.sequences:
output.write(struct.pack('I', len(sequence)))
for sample in sequence:
self.write_floats(output, sample)
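To make the on-disk layout concrete, here is a minimal standalone sketch of one dense sequence record, matching write_data above (made-up sample values, float precision assumed; not part of the converter):
import struct
# A dense sequence record: uint32 sample count, then sample_dim floats per sample.
dim = 3
samples = [[0.5, 1.0, 1.5], [2.0, 2.5, 3.0]]
record = struct.pack('I', len(samples))
for s in samples:
    record += b''.join(struct.pack('f', v) for v in s)
# Reading it back: the count first, then count * dim float values.
(count,) = struct.unpack_from('I', record, 0)
values = struct.unpack_from('%df' % (count * dim), record, 4)
assert count == 2 and abs(values[3] - 2.0) < 1e-6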
# Specialization for sparse inputs
class SparseConverter(Converter):
def __init__(self, name, sampleDim):
Converter.__init__(self, name, sampleDim)
def appendSample(self, sample):
for pair in sample:
index = int(pair.split(":")[0])
if (index >= self.sampleDim):
def add_sample(self, sample):
pairs = map(lambda x: (int(x[0]),float(x[1])),
[pair.split(':', 1) for pair in sample])
for pair in pairs:
index = pair[0]
if (index >= self.sample_dim):
raise ValueError("Invalid sample dimension for input {0}. Max {1}, given {2}"
.format(self.name, self.sampleDim, index))
if( len(self.vals) == 0 ):
self.vals.append( list() )
.format(self.name, self.sample_dim, index))
self.vals[-1].append( sample )
byte_size = len(pairs) * (8 if self.is_float() else 12) + 4
def headerBytes(self):
output = bytearray()
# First is the matrix type. Sparse is type 1
output += struct.pack( "i", 1 )
# Next is the storage type, currently sparse csc only
output += struct.pack( "i", 0 )
# Next is the elem type, currently float only
output += struct.pack( "i", 0 )
# Next is whether or not this is a sequence
# Note this is currently ignored
output += struct.pack( "i", 1 )
# Finally is the sample dimension
output += struct.pack( "i", self.sampleDim )
if(len(self.sequences) == 0):
self.sequences.append([])
byte_size += 8;
return output
self.sequences[-1].append(pairs)
def toBytes(self):
output = bytearray()
values = list()
rowInd = list()
colInd = [0]
nnz = 0
for sequence in self.vals:
i = 0
return byte_size
def get_matrix_type(self):
return MatrixEncodingType.SPARSE_CSC;
def write_data(self, output):
format = 'f' if self.is_float() else 'd'
for sequence in self.sequences:
# write out each sequence in csc format
values = []
indices = []
sizes = []
for sample in sequence:
# sort the indices least to greatest
sample.sort(key=lambda x: int(x.split(":")[0]))
for ele in sample:
nnz += 1
ind, val = ele.split(":")
rowInd.append( int(ind) + i * self.sampleDim )
values.append( val )
i += 1
colInd.append( nnz )
sizes.append(len(sample))
sample.sort(key=lambda x: x[0])
for (index, value) in sample:
indices.append(index)
values.append(value)
output += struct.pack( "i", nnz )
output += b''.join( [ struct.pack( "f", float(val) ) for val in values ] )
output += b''.join( [ struct.pack( "i", int(ind) ) for ind in rowInd ] )
output += b''.join( [ struct.pack( "i", int(ind) ) for ind in colInd ] )
output.write(struct.pack('I', len(sequence))) # number of samples in this sequence
# nnz and indices have to be written out as signed ints, since
# this is the index type of the CNTK sparse matrix
output.write(struct.pack('i', len(values))) # total nnz count for this sequence
self.write_floats(output, values)
self.write_signed_ints(output, indices)
self.write_signed_ints(output, sizes)
return output
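For the sparse case, each sequence becomes one CSC-style record. A minimal standalone sketch of that record's layout (made-up values, float precision assumed):
import struct
# One sparse sequence: two samples given as (index, value) pairs.
sequence = [[(1, 0.5), (4, 1.5)], [(0, 2.0)]]
nnz = sum(len(s) for s in sequence)
record = struct.pack('I', len(sequence))   # number of samples
record += struct.pack('i', nnz)            # total nnz (signed int)
record += b''.join(struct.pack('f', v) for s in sequence for (_, v) in s)
record += b''.join(struct.pack('i', i) for s in sequence for (i, _) in s)
record += b''.join(struct.pack('i', len(s)) for s in sequence)  # nnz per sample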
# Parse an entire sequence given an aliasToId map, and the converters
def ParseSequence( aliasToId, curSequence, converters ):
for des in converters:
des.addSequence()
for line in curSequence:
for input in line.split( "|" )[1:]:
vals = input.split()
# Process the entire sequence
def process_sequence(data, converters, chunk):
byte_size = 0;
for converter in converters.values():
converter.start_sequence()
for line in data:
for input_stream in line.split("|")[1:]:
split = input_stream.split(None, 1)
if (len(split) < 2):
continue
(alias, values) = split
# We need to ignore comments
if( vals[0] != "#" ):
converters[aliasToId[vals[0]]].appendSample( vals[1:] )
return max( [ len(des.vals[ -1 ]) for des in converters ] )
if(len(alias) > 0 and alias[0] != '#'):
byte_size += converters[alias].add_sample(values.split())
sequence_length_samples = max([len(x.sequences[-1]) for x in converters.values()])
chunk.add_sequence(sequence_length_samples)
return byte_size
# Output a binary chunk
def OutputChunk( binfile, converters ):
startPos = binfile.tell()
for des in converters:
binfile.write( des.toBytes() )
des.clear()
return startPos
def write_chunk(binfile, converters, chunk):
binfile.flush()
chunk.offset = binfile.tell()
# write out the number of samples for each sequence in the chunk
binfile.write(b''.join([struct.pack('I', x) for x in chunk.sequences]))
# Get a converter from a type
def GetConverter( inputtype, name, sampleDim ):
converter = None
if( inputtype.lower() == 'dense' ):
converter = DenseConverter( name, sampleDim )
elif( inputtype.lower() == 'sparse' ):
converter = SparseConverter( name, sampleDim )
else:
print('Invalid input format {0}'.format( inputtype ))
sys.exit()
for converter in converters.values():
converter.write_data(binfile)
converter.reset()
# TODO: add a hash of the chunk
return converter
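Together with write_chunk above, a chunk on disk is a table of per-sequence sample counts followed by each stream's payload. A hedged sketch of reading that prefix back (the offset and sequence count come from the chunk table; the function name is illustrative):
import struct
def read_sequence_lengths(f, chunk_offset, num_sequences):
    # A chunk starts with one uint32 sample count per sequence;
    # the per-stream payloads follow immediately after this table.
    f.seek(chunk_offset)
    return struct.unpack('%dI' % num_sequences, f.read(4 * num_sequences))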
def get_converter(input_type, name, sample_dim, element_type):
if(input_type.lower() == 'dense'):
return DenseConverter(name, sample_dim, element_type)
if(input_type.lower() == 'sparse'):
return SparseConverter(name, sample_dim, element_type)
# Output the binary format header.
def OutputHeader( headerFile, converters ):
# First the version number
headerFile.write( struct.pack( "q", 1 ) )
# Next is the number of chunks, but we don't know what this is, so write a
# placeholder
headerFile.write( struct.pack( "q", 0 ) )
# Finally the number of inputs
headerFile.write( struct.pack( "i", len(converters) ) )
for conv in converters:
# first comes the name. This is common so write it first
headerFile.write( struct.pack( "i", len( conv.getName() ) ) )
headerFile.write( conv.getName().encode('ascii') )
headerFile.write( conv.headerBytes() )
# At the end we know how many chunks there are. Update the header as needed.
def UpdateHeader( headerFile, numChunks ):
# seek after the first Int64
headerFile.seek( 8 )
# Write the number of chunks
headerFile.write( struct.pack( "q", numChunks ) )
raise ValueError('Invalid input format {0}'.format(input_type))
# parse the header to get the converters for this file
# <name> <alias> <input format> <sample size>
def build_converters(header_file, element_type):
converters = OrderedDict();
with open(header_file, 'r') as inputs:
for line in inputs:
(name, alias, input_type, sample_dim) = line.strip().split()
converters[alias] = get_converter(input_type, name, int(sample_dim), element_type)
return converters
class Chunk:
def __init__(self):
self.offset = 0
self.sequences = []
def num_sequences(self):
return len(self.sequences)
def num_samples(self):
return sum(self.sequences)
def add_sequence(self, num_samples):
return self.sequences.append(num_samples)
class Header:
def __init__(self, converters):
self.converters = converters
self.chunks = []
def add_chunk(self, chunk):
assert(isinstance(chunk, Chunk))
self.chunks.append(chunk)
# Output the binary format header.
def write(self, output_file):
output_file.flush()
header_offset = output_file.tell()
# First, write the magic number (uint64, 8 bytes)
output_file.write(struct.pack('Q', MAGIC_NUMBER));
# Next is the number of chunks (uint32, 4 bytes)
output_file.write(struct.pack('I', len(self.chunks)))
# Finally the number of input streams (uint32, 4 bytes)
output_file.write(struct.pack('I', len(self.converters)))
for converter in self.converters.values():
converter.write_header(output_file)
# write the chunk table
for chunk in self.chunks:
# uint64: start offset for chunk
output_file.write(struct.pack('q', chunk.offset))
# uint32: number of sequences in the chunk
output_file.write(struct.pack('I', chunk.num_sequences()))
# uint32: number of samples in the chunk
output_file.write(struct.pack('I', chunk.num_samples()))
output_file.write(struct.pack('q', header_offset));
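Overall, the produced file is: magic + version, the chunk payloads, then this header (magic, chunk count, stream count, per-stream headers, chunk table), and finally an int64 holding the header's offset, which is how the reader locates it. A minimal sketch for locating and sanity-checking the header, reusing the script's MAGIC_NUMBER and CBF_VERSION constants (illustrative only):
import struct
def read_cbf_preamble(path):
    with open(path, 'rb') as f:
        # Leading magic number and format version.
        magic, version = struct.unpack('QI', f.read(12))
        assert magic == MAGIC_NUMBER and version == CBF_VERSION
        # The last 8 bytes hold the header offset; the header repeats the magic.
        f.seek(-8, 2)  # 2 == SEEK_END
        (header_offset,) = struct.unpack('q', f.read(8))
        f.seek(header_offset)
        (header_magic,) = struct.unpack('Q', f.read(8))
        assert header_magic == MAGIC_NUMBER
        num_chunks, num_streams = struct.unpack('II', f.read(8))
        return num_chunks, num_streams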
# Output a single row of the offsets table
def OutputOffset( headerFile, numBytes, numSeqs, numSamples ):
# Int64 start offset for chunk
headerFile.write( struct.pack( "q", numBytes ) )
# Int32 Num sequences in the chunk
headerFile.write( struct.pack( "i", numSeqs ) )
# Int32 Num samples in the chunk
headerFile.write( struct.pack( "i", numSamples ) )
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Transforms a CNTK Text Format file into CNTK binary format given a header.")
parser.add_argument('--input', help="CNTK Text Format file to convert to binary.", default="", required=True)
parser.add_argument('--header', help="Header file describing each stream in the input.", default="", required=True)
parser.add_argument('--seqsPerChunk', type=int, help='Number of sequences in each chunk.', default="", required=True)
parser.add_argument('--output', help='Name of the output file, stdout if not given', default="", required=True)
parser.add_argument('--input', help="CNTK Text Format file to convert to binary.", required=True)
parser.add_argument('--header', help="Header file describing each stream in the input.", required=True)
parser.add_argument('--chunk_size', type=int, help='Chunk size in bytes.', required=True)
parser.add_argument('--output', help='Name of the output file.', required=True)
parser.add_argument('--precision', help='Floating point precision (double or float). Default is float',
choices=["float", "double"], default="float", required=False)
args = parser.parse_args()
# Since we don't know how many chunks we're going to write until we're done,
# grow the header/offsets table and the data portion separately. then at the
# end concatenate the data portion onto the end of the header/offsets
# portion.
binaryHeaderFile = open( args.output, "wb+" )
binaryDataFile = tempfile.NamedTemporaryFile(mode="rb+", delete=False)
dataPath = binaryDataFile.name
output = open(args.output, "wb")
# The very first 8 bytes of the file is the CBF magic number.
output.write(struct.pack('Q', MAGIC_NUMBER));
# Next 4 bytes is the CBF version.
output.write(struct.pack('I', CBF_VERSION));
# parse the header to get the converters for this file
# <name> <alias> <input format> <sample size>
converters = []
aliasToId = dict()
with open( args.header, "r" ) as headerfile:
id = 0
for line in headerfile:
split = re.split(r'\t+', line.strip())
converters.append( GetConverter( split[ 2 ], split[ 0 ], int(split[3]) ) )
aliasToId[ split[ 1 ] ] = id
id += 1
converters = build_converters(args.header,
ElementType.FLOAT if args.precision == 'float' else ElementType.DOUBLE)
OutputHeader( binaryHeaderFile, converters )
numChunks = 0
with open( args.input, "r" ) as inputFile:
curSequence = list()
numSeqs = 0
numSamps = 0
prevId = None
for line in inputFile:
split = line.rstrip().split('|')
header = Header(converters)
chunk = Chunk()
with open(args.input, "r") as input_file:
sequence = []
seq_id = None
estimated_chunk_size = 0
for line in input_file:
(prefix, _) = line.rstrip().split('|',1)
# if the sequence id is empty or not equal to the previous sequence id,
# we are at a new sequence.
if( not split[0] or prevId != split[ 0 ] ):
if(len(curSequence) > 0):
numSamps += ParseSequence( aliasToId, curSequence, converters )
curSequence = list()
numSeqs += 1
if( numSeqs % int( args.seqsPerChunk ) == 0 ):
numBytes = OutputChunk( binaryDataFile, converters )
numChunks += 1
OutputOffset( binaryHeaderFile, numBytes, numSeqs, numSamps )
numSeqs = 0
numSamps = 0
prevId = split[ 0 ]
if((not seq_id and not prefix) or (len(prefix) > 0 and seq_id != prefix)):
if(len(sequence) > 0):
estimated_chunk_size += process_sequence(sequence, converters, chunk)
sequence = []
if(estimated_chunk_size >= int(args.chunk_size)):
write_chunk(output, converters, chunk)
header.add_chunk(chunk)
chunk = Chunk()
seq_id = prefix
curSequence.append( line )
sequence.append(line)
# we must process the last buffered sequence
if( len(curSequence) > 0 ):
numSamps += ParseSequence( aliasToId, curSequence, converters )
numSeqs += 1
numChunks += 1
if(len(sequence) > 0):
process_sequence(sequence, converters, chunk)
numBytes = OutputChunk( binaryDataFile, converters )
OutputOffset( binaryHeaderFile, numBytes, numSeqs, numSamps )
write_chunk(output, converters, chunk)
header.add_chunk(chunk)
UpdateHeader( binaryHeaderFile, numChunks )
binaryHeaderFile.flush()
binaryDataFile.flush()
binaryHeaderFile.close()
binaryDataFile.close()
header.write(output)
destination = open( args.output, 'ab+' )
shutil.copyfileobj( open( dataPath, "rb" ), destination )
destination.flush()
destination.close()
os.unlink(dataPath)
assert not os.path.exists(dataPath)
output.flush()
output.close()
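For reference, a hypothetical invocation of the converter (file names and the 32 MB chunk size are placeholders):
python ctf2bin.py --input train.ctf --header streams.txt --chunk_size 33554432 --output train.bin --precision float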

View file

@@ -4,6 +4,8 @@
//
#include "stdafx.h"
#define __STDC_FORMAT_MACROS
#include <inttypes.h>
#include "BinaryChunkDeserializer.h"
#include "BinaryDataChunk.h"
#include "FileHelper.h"
@@ -11,43 +13,52 @@
namespace Microsoft { namespace MSR { namespace CNTK {
enum class DeserializerType : int32_t
enum class MatrixEncodingType : unsigned char
{
DenseBinaryDataDeserializer = 0,
SparseBinaryDataDeserializer = 1
dense = 0,
sparse_csc = 1,
// TODO: compressed_sparse_csc = 2, // indices are encoded as var-ints
};
void BinaryChunkDeserializer::ReadOffsetsTable(FILE* infile)
void BinaryChunkDeserializer::ReadChunkTable(FILE* infile)
{
ReadOffsetsTable(infile, 0, m_numChunks);
ReadChunkTable(infile, 0, m_numChunks);
}
void BinaryChunkDeserializer::ReadOffsetsTable(FILE* infile, size_t startOffset, size_t numChunks)
void BinaryChunkDeserializer::ReadChunkTable(FILE* infile, uint32_t firstChunkIdx, uint32_t numChunks)
{
assert((int64_t)(startOffset + numChunks) <= m_numChunks);
size_t startPos = startOffset * sizeof(DiskOffsetsTable) + m_offsetStart;
if (firstChunkIdx + numChunks > m_numChunks)
{
RuntimeError("Requested chunks (from %" PRIu32 " to %" PRIu32 ") are out of bounds "
"(the total number of chunks in the dataset is %" PRIu32 ").",
firstChunkIdx, (firstChunkIdx + numChunks - 1), m_numChunks);
}
// Seek to the offsets table start
CNTKBinaryFileHelper::seekOrDie(infile, startPos, SEEK_SET);
uint64_t firstChunkOffset = firstChunkIdx * sizeof(ChunkInfo) + m_chunkTableOffset;
// Seek to the start of the offset info for the first requested chunk
CNTKBinaryFileHelper::SeekOrDie(infile, firstChunkOffset, SEEK_SET);
// Note we create numChunks + 1 since we want to be consistent with determining the size of each chunk.
DiskOffsetsTable* offsetsTable = new DiskOffsetsTable[numChunks + 1];
ChunkInfo* chunks = new ChunkInfo[numChunks + 1];
// Read in all of the offsets for the chunks of interest
CNTKBinaryFileHelper::readOrDie(offsetsTable, sizeof(DiskOffsetsTable), numChunks, infile);
CNTKBinaryFileHelper::ReadOrDie(chunks, sizeof(ChunkInfo), numChunks, infile);
// Now read the final entry. It is either the next offset entry (if we're reading a subset and the
// entry exists), or we just fill it with the correct information based on file size if it doesn't
if ((int64_t)(startOffset + numChunks) == m_numChunks)
if (firstChunkIdx + numChunks == m_numChunks)
{
CNTKBinaryFileHelper::seekOrDie(infile, 0, SEEK_END);
offsetsTable[numChunks].offset = CNTKBinaryFileHelper::tellOrDie(infile) - m_dataStart;
offsetsTable[numChunks].numSamples = 0;
offsetsTable[numChunks].numSequences = 0;
auto position = CNTKBinaryFileHelper::TellOrDie(infile);
chunks[numChunks].offset = position;
chunks[numChunks].numSamples = 0;
chunks[numChunks].numSequences = 0;
}
else
CNTKBinaryFileHelper::readOrDie(offsetsTable + numChunks, sizeof(DiskOffsetsTable), 1, infile);
CNTKBinaryFileHelper::ReadOrDie(chunks + numChunks, sizeof(ChunkInfo), 1, infile);
m_offsetsTable = make_unique<OffsetsTable>(numChunks, offsetsTable);
m_chunkTable = make_unique<ChunkTable>(numChunks, chunks);
}
@@ -56,7 +67,7 @@ BinaryChunkDeserializer::BinaryChunkDeserializer(const BinaryConfigHelper& helpe
{
SetTraceLevel(helper.GetTraceLevel());
Initialize(helper.GetRename());
Initialize(helper.GetRename(), helper.GetElementType());
}
@@ -64,8 +75,8 @@ BinaryChunkDeserializer::BinaryChunkDeserializer(const std::wstring& filename) :
DataDeserializerBase(true),
m_filename(filename),
m_file(nullptr),
m_offsetStart(0),
m_dataStart(0),
m_headerOffset(0),
m_chunkTableOffset(0),
m_traceLevel(0)
{
}
@@ -73,103 +84,85 @@ BinaryChunkDeserializer::BinaryChunkDeserializer(const std::wstring& filename) :
BinaryChunkDeserializer::~BinaryChunkDeserializer()
{
if (m_file)
CNTKBinaryFileHelper::closeOrDie(m_file);
CNTKBinaryFileHelper::CloseOrDie(m_file);
}
void BinaryChunkDeserializer::Initialize(const std::map<std::wstring, std::wstring>& rename)
void BinaryChunkDeserializer::Initialize(const std::map<std::wstring, std::wstring>& rename, ElementType precision)
{
if (m_file)
CNTKBinaryFileHelper::closeOrDie(m_file);
CNTKBinaryFileHelper::CloseOrDie(m_file);
m_file = CNTKBinaryFileHelper::openOrDie(m_filename, L"rb");
m_file = CNTKBinaryFileHelper::OpenOrDie(m_filename, L"rb");
// We are now parsing the header. Seek to the head of the header to start.
CNTKBinaryFileHelper::seekOrDie(m_file, 0, SEEK_SET);
// First, verify the magic number.
CNTKBinaryFileHelper::FindMagicOrDie(m_file, m_filename);
// Second, read the version number of the data file, and (for now) make sure the reader version is the same.
uint32_t versionNumber = CNTKBinaryFileHelper::GetVersionNumber(m_file);
if (versionNumber != s_currentVersion)
LogicError("The reader version is %" PRIu32 ", but the data file was created for version %" PRIu32 ".",
s_currentVersion, versionNumber);
// First read the version number of the data file, and make sure the reader version is the same.
int64_t versionNumber;
CNTKBinaryFileHelper::readOrDie(&versionNumber, sizeof(versionNumber), 1, m_file);
if (versionNumber != m_versionNumber)
LogicError("The reader version is %d, but the data file was created for version %d.", (int)m_versionNumber, (int)versionNumber);
// Now, find where the header is.
m_headerOffset = CNTKBinaryFileHelper::GetHeaderOffset(m_file);
CNTKBinaryFileHelper::SeekOrDie(m_file, m_headerOffset, SEEK_SET);
// Once again, make sure that the header is well-formed and starts with a magic number.
CNTKBinaryFileHelper::FindMagicOrDie(m_file, m_filename);
// Next is the number of chunks in the input file.
CNTKBinaryFileHelper::readOrDie(&m_numChunks, sizeof(m_numChunks), 1, m_file);
CNTKBinaryFileHelper::ReadOrDie(&m_numChunks, sizeof(m_numChunks), 1, m_file);
// Next is the number of inputs
CNTKBinaryFileHelper::readOrDie(&m_numInputs, sizeof(m_numInputs), 1, m_file);
CNTKBinaryFileHelper::ReadOrDie(&m_numInputs, sizeof(m_numInputs), 1, m_file);
// Reserve space for all of the inputs, and then read them in.
m_streams.resize(m_numInputs);
m_deserializers.resize(m_numInputs);
int32_t len;
// 100 characters should be plenty by default, but grow if necessary.
vector<char> tempName(100);
for (int32_t c = 0; c < m_numInputs; c++)
for (decltype(m_numInputs) i = 0; i < m_numInputs; i++)
{
// Create our streamDescription for this input
auto streamDescription = std::make_shared<StreamDescription>();
MatrixEncodingType type;
CNTKBinaryFileHelper::ReadOrDie(&type, sizeof(type), 1, m_file);
if (type == MatrixEncodingType::dense)
m_deserializers[i] = make_shared<DenseBinaryDataDeserializer>(m_file, precision);
else if (type == MatrixEncodingType::sparse_csc)
m_deserializers[i] = make_shared<SparseBinaryDataDeserializer>(m_file, precision);
else
RuntimeError("Unknown encoding type %u requested.", type);
// read the name
CNTKBinaryFileHelper::readOrDie(&len, sizeof(len), 1, m_file);
// Need 1 extra char for the null.
tempName.resize(len+1);
CNTKBinaryFileHelper::readOrDie(tempName.data(), sizeof(char), len, m_file);
tempName[len] = '\0';
wstring wname = msra::strfun::utf16(tempName.data());
auto description = m_deserializers[i]->GetStreamDescription();
description->m_id = i;
// Check if we should rename this input based on the config
if (rename.find(wname) == rename.end())
streamDescription->m_name = wname;
else
streamDescription->m_name = rename.at(wname);
auto it = rename.find(description->m_name);
if (it != rename.end())
{
description->m_name = it->second;
}
// Read the matrix type. Then instantiate the appropriate BinaryDataDeserializer, and have it read in its parameters
// Note: Is there a better way to do this?
DeserializerType desType;
CNTKBinaryFileHelper::readOrDie(&desType, sizeof(desType), 1, m_file);
if (desType == DeserializerType::DenseBinaryDataDeserializer)
m_deserializers[c] = make_shared<DenseBinaryDataDeserializer>(m_file);
else if (desType == DeserializerType::SparseBinaryDataDeserializer)
m_deserializers[c] = make_shared<SparseBinaryDataDeserializer>(m_file);
else
RuntimeError("Unknown deserializer type %d requested.", (int)desType);
streamDescription->m_id = c;
streamDescription->m_elementType = m_deserializers[c]->GetElementType();
streamDescription->m_storageType = m_deserializers[c]->GetStorageType();
streamDescription->m_sampleLayout = m_deserializers[c]->GetSampleLayout();
m_streams[c] = streamDescription;
m_streams[i] = description;
}
// We just finished the header. So we're now at the offsets table.
m_offsetStart = CNTKBinaryFileHelper::tellOrDie(m_file);
// After the header is the data start. Compute that now.
m_dataStart = m_offsetStart + m_numChunks * sizeof(DiskOffsetsTable);
// We just finished the header. So we're now at the chunk table.
m_chunkTableOffset = CNTKBinaryFileHelper::TellOrDie(m_file);
// We only have to read in the offsets table once, so do that now.
// Note it's possible in distributed reading mode to only want to read
// a subset of the offsets table.
ReadOffsetsTable(m_file);
ReadChunkTable(m_file);
}
ChunkDescriptions BinaryChunkDeserializer::GetChunkDescriptions()
{
assert(m_offsetsTable);
assert(m_chunkTable);
ChunkDescriptions result;
result.reserve(m_numChunks);
if (m_numChunks > CHUNKID_MAX)
RuntimeError("Currently CNTK does not support %d chunks. The maximum number of chunks allowed is %d.", (int)m_numChunks, (int)CHUNKID_MAX);
for (ChunkIdType c = 0; c < (ChunkIdType)m_numChunks; c++ )
for (ChunkIdType i = 0; i < m_numChunks; i++ )
{
result.push_back(shared_ptr<ChunkDescription>(
new ChunkDescription {
c,
(size_t)m_offsetsTable->GetNumSamples(c),
(size_t)m_offsetsTable->GetNumSequences(c)
}));
new ChunkDescription{ i, m_chunkTable->GetNumSamples(i), m_chunkTable->GetNumSequences(i) }));
}
return result;
@@ -178,33 +171,25 @@ ChunkDescriptions BinaryChunkDeserializer::GetChunkDescriptions()
void BinaryChunkDeserializer::GetSequencesForChunk(ChunkIdType chunkId, std::vector<SequenceDescription>& result)
{
// Reserve space for each sequence
result.reserve(m_offsetsTable->GetNumSequences(chunkId));
result.reserve(m_chunkTable->GetNumSequences(chunkId));
// We don't store every piece of sequence information, so we have to read the chunk in, parse it, and then
// find the information.
// BUGBUG: Note this requires reading each chunk twice. This might not be hugely disadvantageous due to OS
// caching, but should be avoided none the less.
ChunkPtr chunk = GetChunk(chunkId);
auto offset = m_chunkTable->GetOffset(chunkId);
auto numberOfSequences = m_chunkTable->GetNumSequences(chunkId);
unique_ptr<uint32_t[]> numSamplesPerSequence(new uint32_t[numberOfSequences]);
size_t startId = m_offsetsTable->GetStartIndex(chunkId);
std::vector<SequenceDataPtr> temp;
for (size_t c = 0; c < m_offsetsTable->GetNumSequences(chunkId); c++)
// Seek to the start of the chunk
CNTKBinaryFileHelper::SeekOrDie(m_file, offset, SEEK_SET);
// read 'numberOfSequences' unsigned ints
CNTKBinaryFileHelper::ReadOrDie(numSamplesPerSequence.get(), sizeof(uint32_t), numberOfSequences, m_file);
auto startId = m_chunkTable->GetStartIndex(chunkId);
for (decltype(numberOfSequences) i = 0; i < numberOfSequences; i++)
{
// BUGBUG: This is inefficient, but we don't have a choice. Why do we need this at all? Why can't
// this information just be gotten from the chunks? It's not clear.
// Note numSamples is 1 if there are no sequences.
uint32_t numSamples = 1;
temp.clear();
chunk->GetSequence(m_offsetsTable->GetStartIndex(chunkId) + c, temp);
// Only take the max over streams that are actually in use.
for (size_t i = 0; i < temp.size(); i++)
numSamples = max(numSamples, temp[i]->m_numberOfSamples);
SequenceDescription sd = {};
sd.m_indexInChunk = startId + c;
sd.m_numberOfSamples = numSamples;
sd.m_indexInChunk = i;
sd.m_numberOfSamples = numSamplesPerSequence[i];
sd.m_chunkId = chunkId;
sd.m_key.m_sequence = startId + c;
sd.m_key.m_sequence = startId + i;
sd.m_key.m_sample = 0;
result.push_back(sd);
@@ -213,17 +198,18 @@ void BinaryChunkDeserializer::GetSequencesForChunk(ChunkIdType chunkId, std::vec
unique_ptr<byte[]> BinaryChunkDeserializer::ReadChunk(ChunkIdType chunkId)
{
// Seek to the start of the chunk
CNTKBinaryFileHelper::seekOrDie(m_file, m_dataStart + m_offsetsTable->GetOffset(chunkId), SEEK_SET);
// Seek to the start of the data portion in the chunk
CNTKBinaryFileHelper::SeekOrDie(m_file, m_chunkTable->GetDataStartOffset(chunkId), SEEK_SET);
// Determine how big the chunk is.
size_t chunkSize = m_offsetsTable->GetChunkSize(chunkId);
size_t chunkSize = m_chunkTable->GetChunkSize(chunkId);
// Create buffer
// TODO: use a pool of buffers instead of allocating a new one, each time a chunk is read.
unique_ptr<byte[]> buffer(new byte[chunkSize]);
// Read the chunk from disk
CNTKBinaryFileHelper::readOrDie(buffer.get(), sizeof(byte), chunkSize, m_file);
CNTKBinaryFileHelper::ReadOrDie(buffer.get(), sizeof(byte), chunkSize, m_file);
return buffer;
}
@@ -232,9 +218,9 @@ unique_ptr<byte[]> BinaryChunkDeserializer::ReadChunk(ChunkIdType chunkId)
ChunkPtr BinaryChunkDeserializer::GetChunk(ChunkIdType chunkId)
{
// Read the chunk into memory
unique_ptr<byte[]> chunkBuffer = ReadChunk(chunkId);
unique_ptr<byte[]> buffer = ReadChunk(chunkId);
return make_shared<BinaryDataChunk>(chunkId, m_offsetsTable->GetStartIndex(chunkId), m_offsetsTable->GetNumSequences(chunkId), std::move(chunkBuffer), m_deserializers);
return make_shared<BinaryDataChunk>(chunkId, m_chunkTable->GetNumSequences(chunkId), std::move(buffer), m_deserializers);
}
void BinaryChunkDeserializer::SetTraceLevel(unsigned int traceLevel)

View file

@@ -13,55 +13,78 @@
namespace Microsoft { namespace MSR { namespace CNTK {
// Offsets table used to find the chunks in the binary file. Added some helper methods around the core data.
#pragma pack(push, 1)
struct DiskOffsetsTable
// Chunk meta-info: byte offset in the input file, number of sequences and samples in the chunk.
struct ChunkInfo
{
int64_t offset;
int32_t numSequences;
int32_t numSamples;
uint32_t numSequences;
uint32_t numSamples;
};
#pragma pack(pop)
// Offsets table used to find the chunks in the binary file. Added some helper methods around the core data.
class OffsetsTable {
// Chunk table used to find the chunks in the binary file. Added some helper methods around the core data.
class ChunkTable {
public:
OffsetsTable(size_t numChunks, DiskOffsetsTable* offsetsTable) : m_numChunks(numChunks)
ChunkTable(uint32_t numChunks, ChunkInfo* offsetsTable) :
m_numChunks(numChunks),
m_diskOffsetsTable(offsetsTable),
m_startIndex(numChunks)
{
m_diskOffsetsTable = make_unique<DiskOffsetsTable*>(offsetsTable);
Initialize();
uint64_t numSequences = 0;
for (decltype(m_numChunks) i = 0; i < m_numChunks; i++)
{
m_startIndex[i] = numSequences;
numSequences += m_diskOffsetsTable[i].numSequences;
}
}
int64_t GetOffset(size_t index) { return (*m_diskOffsetsTable)[index].offset; }
int32_t GetNumSequences(size_t index) { return (*m_diskOffsetsTable)[index].numSequences; }
int32_t GetNumSamples(size_t index) { return (*m_diskOffsetsTable)[index].numSamples; }
int64_t GetStartIndex(size_t index) { return m_startIndex[index]; }
size_t GetChunkSize(size_t index) { return (*m_diskOffsetsTable)[index + 1].offset - (*m_diskOffsetsTable)[index].offset; }
int64_t GetOffset(uint32_t index)
{
return m_diskOffsetsTable[index].offset;
}
private:
void Initialize()
int64_t GetDataStartOffset(uint32_t index)
{
m_startIndex.resize(m_numChunks);
m_startIndex[0] = 0;
for (int64_t c = 1; c < m_numChunks; c++)
m_startIndex[c] = m_startIndex[c-1] + (*m_diskOffsetsTable)[c].numSequences;
auto sequenceLengthPrefix = GetNumSequences(index) * sizeof(uint32_t);
return GetOffset(index) + sequenceLengthPrefix;
}
uint32_t GetNumSequences(uint32_t index)
{
return m_diskOffsetsTable[index].numSequences;
}
uint32_t GetNumSamples(uint32_t index)
{
return m_diskOffsetsTable[index].numSamples;
}
int64_t GetStartIndex(uint32_t index)
{
return m_startIndex.at(index);
}
uint64_t GetChunkSize(uint32_t index)
{
auto dataStartOffset = GetDataStartOffset(index);
auto dataEndOffset = GetOffset(index + 1);
return dataEndOffset - dataStartOffset;
}
private:
int64_t m_numChunks;
unique_ptr<DiskOffsetsTable*> m_diskOffsetsTable;
vector<size_t> m_startIndex;
uint32_t m_numChunks;
unique_ptr<ChunkInfo[]> m_diskOffsetsTable;
vector<uint64_t> m_startIndex;
};
typedef unique_ptr<OffsetsTable> OffsetsTablePtr;
typedef unique_ptr<ChunkTable> ChunkTablePtr;
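Since ChunkInfo is packed to 1-byte alignment, every chunk table entry is exactly 16 bytes: an int64 offset plus two uint32 counts. A sketch of decoding the table, sticking to the converter's language (Python; little-endian layout assumed, illustrative only):
import struct
CHUNK_INFO = struct.Struct('<qII')  # offset, numSequences, numSamples: 16 bytes
def parse_chunk_table(buf, num_chunks):
    return [CHUNK_INFO.unpack_from(buf, i * CHUNK_INFO.size)
            for i in range(num_chunks)]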
// TODO: more details when tracing warnings
class BinaryChunkDeserializer : public DataDeserializerBase {
public:
explicit BinaryChunkDeserializer(const BinaryConfigHelper& helper);
BinaryChunkDeserializer(CorpusDescriptorPtr corpus, const BinaryConfigHelper& helper);
BinaryChunkDeserializer(CorpusDescriptorPtr corpus, const BinaryConfigHelper& helper) = delete;
~BinaryChunkDeserializer();
@@ -74,16 +97,13 @@ public:
// Get information about particular chunk.
void GetSequencesForChunk(ChunkIdType chunkId, vector<SequenceDescription>& result) override;
// Parses buffer into a BinaryChunkPtr
void ParseChunk(ChunkIdType chunkId, unique_ptr<byte[]> const& buffer, std::vector<std::vector<SequenceDataPtr>>& data);
private:
// Builds an index of the input data.
void Initialize(const std::map<std::wstring, std::wstring>& rename);
void Initialize(const std::map<std::wstring, std::wstring>& rename, ElementType precision);
// Reads the offsets table from disk into memory
void ReadOffsetsTable(FILE* infile, size_t startOffset, size_t numChunks);
void ReadOffsetsTable(FILE* infile);
// Reads the chunk table from disk into memory
void ReadChunkTable(FILE* infile, uint32_t firstChunkIdx, uint32_t numChunks);
void ReadChunkTable(FILE* infile);
// Reads a chunk from disk into buffer
unique_ptr<byte[]> ReadChunk(ChunkIdType chunkId);
@@ -96,22 +116,23 @@ private:
const wstring m_filename;
FILE* m_file;
int64_t m_offsetStart;
int64_t m_dataStart;
int64_t m_headerOffset, m_chunkTableOffset;
std::vector<BinaryDataDeserializerPtr> m_deserializers;
OffsetsTablePtr m_offsetsTable;
ChunkTablePtr m_chunkTable;
void* m_chunkBuffer;
int64_t m_versionNumber = 1;
int64_t m_numChunks;
int32_t m_numInputs;
uint32_t m_numChunks;
uint32_t m_numInputs;
unsigned int m_traceLevel;
static const uint32_t s_currentVersion = 1;
friend class CNTKBinaryReaderTestRunner;
DISABLE_COPY_AND_MOVE(BinaryChunkDeserializer);
};
}}}

View file

@@ -48,6 +48,20 @@ namespace Microsoft { namespace MSR { namespace CNTK {
}
}
string precision = config.Find("precision", "float");
if (AreEqualIgnoreCase(precision, "double"))
{
m_elementType = ElementType::tdouble;
}
else if (AreEqualIgnoreCase(precision, "float"))
{
m_elementType = ElementType::tfloat;
}
else
{
RuntimeError("Not supported precision '%s'. Expected 'double' or 'float'.", precision.c_str());
}
m_filepath = msra::strfun::utf16(config(L"file"));
m_keepDataInMemory = config(L"keepDataInMemory", false);

View file

@@ -34,11 +34,14 @@ public:
bool ShouldKeepDataInMemory() const { return m_keepDataInMemory; }
ElementType GetElementType() const { return m_elementType; }
DISABLE_COPY_AND_MOVE(BinaryConfigHelper);
private:
std::wstring m_filepath;
std::map<std::wstring, std::wstring> m_streams;
ElementType m_elementType;
size_t m_randomizationWindow;
// Specifies how to interpret randomization window, if true randomization window == number of samples, else
// randomization window = number of chunks (default).

View file

@@ -15,30 +15,37 @@ namespace Microsoft { namespace MSR { namespace CNTK {
class BinaryDataChunk : public Chunk, public std::enable_shared_from_this<Chunk>
{
public:
explicit BinaryDataChunk(ChunkIdType chunkId, size_t startSequence, size_t numSequences, unique_ptr<byte[]> buffer, std::vector<BinaryDataDeserializerPtr> deserializer)
: m_chunkId(chunkId), m_startSequence(startSequence), m_numSequences(numSequences), m_buffer(std::move(buffer)), m_deserializers(deserializer)
{
}
explicit BinaryDataChunk(ChunkIdType chunkId,
size_t numSequences,
unique_ptr<byte[]> buffer,
std::vector<BinaryDataDeserializerPtr> deserializer)
: m_chunkId(chunkId),
m_numSequences(numSequences),
m_buffer(std::move(buffer)),
m_deserializers(deserializer)
{ }
// Gets sequences by id.
void GetSequence(size_t sequenceId, std::vector<SequenceDataPtr>& result) override
// Gets a sequence using its index inside the chunk.
void GetSequence(size_t sequenceIdx, std::vector<SequenceDataPtr>& result) override
{
// Check if we've already parsed the chunk. If not, parse it.
if (m_data.size() == 0)
ParseChunk();
assert(m_data.size() != 0);
// resize the output to have the same dimensionality
result.resize(m_data.size());
// now copy the decoded sequences
for (size_t c = 0; c < m_data.size(); c++)
result[c] = m_data[c].at(sequenceId - m_startSequence);
for (size_t i = 0; i < m_data.size(); i++)
result[i] = m_data[i].at(sequenceIdx);
}
uint32_t GetNumSamples(size_t sequenceId)
uint32_t GetNumSamples(size_t sequenceIdx)
{
uint32_t numSamples = 0;
for (size_t c = 0; c < m_data.size(); c++)
numSamples = max(numSamples, m_data[c].at(sequenceId)->m_numberOfSamples);
for (size_t i = 0; i < m_data.size(); i++)
numSamples = max(numSamples, m_data[i].at(sequenceIdx)->m_numberOfSamples);
return numSamples;
}
@@ -50,15 +57,13 @@ protected:
// the number of bytes of buffer that have been processed by the deserializer so far
size_t bytesProcessed = 0;
// Now call all of the deserializers on the chunk, in order
for (size_t c = 0; c < m_deserializers.size(); c++)
bytesProcessed += m_deserializers[c]->GetSequenceDataForChunk(m_numSequences, (byte*)m_buffer.get() + bytesProcessed, m_data[c]);
for (size_t i = 0; i < m_deserializers.size(); i++)
bytesProcessed += m_deserializers[i]->GetSequenceDataForChunk(m_numSequences, m_buffer.get() + bytesProcessed, m_data[i]);
}
// chunk id (copied from the descriptor)
ChunkIdType m_chunkId;
// start id for sequences in this chunk.
size_t m_startSequence;
// num sequences in this chunk. Note this should be in the chunk, but for simplicity it is in the offsets table
// so we must tell the chunk where it starts.
size_t m_numSequences;

View file

@@ -14,40 +14,102 @@
namespace Microsoft { namespace MSR { namespace CNTK {
class BinaryDataDeserialzer {
class BinaryDataDeserialzer
{
public:
BinaryDataDeserialzer(FILE* file, ElementType precision = ElementType::tfloat)
{
ReadName(file);
ReadDataType(file);
ReadSampleSize(file);
if (precision != ElementType::tfloat && precision != ElementType::tdouble)
LogicError("Unsupported precision type %u.", precision);
if ((m_dataType == DataType::tfloat && precision != ElementType::tfloat) ||
(m_dataType == DataType::tdouble && precision != ElementType::tdouble))
LogicError("Unsupported combination of the input data type %u and precision %u. "
"At the moment, both have to match.", m_dataType, precision);
m_precision = precision;
}
virtual size_t GetSequenceDataForChunk(size_t numSequences, void* data, std::vector<SequenceDataPtr>& result) = 0;
StorageType GetStorageType() { return m_storageType; }
ElementType GetElementType() { return m_elemType; }
TensorShapePtr GetSampleLayout() { return make_shared<TensorShape>(m_numCols); }
virtual bool IsSequence() { return false; }
virtual StorageType GetStorageType() = 0;
size_t GetElemSizeBytes()
StreamDescriptionPtr GetStreamDescription()
{
if (m_elemType == ElementType::tfloat)
auto streamDescription = std::make_shared<StreamDescription>();
streamDescription->m_elementType = m_precision;
streamDescription->m_storageType = GetStorageType();
streamDescription->m_sampleLayout = GetSampleLayout();
streamDescription->m_name = m_name;
return streamDescription;
}
TensorShapePtr GetSampleLayout()
{
return make_shared<TensorShape>(m_sampleDimension);
}
size_t SizeOfDataType()
{
if (m_dataType == DataType::tfloat)
return sizeof(float);
else if (m_elemType == ElementType::tdouble)
if (m_dataType == DataType::tdouble)
return sizeof(double);
else
LogicError("Error, elemtype is not defined for BinaryDataDeserializer.");
LogicError("Unsupported input data type %u.", m_dataType);
}
protected:
enum class DataType : unsigned char
{
tfloat = 0,
tdouble = 1,
// TODO:
// tbool = 2, 1 bit per value (one-hot data)
// tbyte = 3, 1 byte per value
};
virtual ~BinaryDataDeserialzer() = default;
void ReadName(FILE* file)
{
uint32_t len;
// read the name
CNTKBinaryFileHelper::ReadOrDie(&len, sizeof(len), 1, file);
vector<char> temp(len + 1, '\0');
CNTKBinaryFileHelper::ReadOrDie(temp.data(), sizeof(char), len, file);
m_name = msra::strfun::utf16(temp.data());
}
void ReadDataType(FILE* file)
{
CNTKBinaryFileHelper::ReadOrDie(&m_dataType, sizeof(m_dataType), 1, file);
if (m_dataType > DataType::tdouble)
RuntimeError("Unsupported input data type %u.", m_dataType);
}
void ReadSampleSize(FILE* file)
{
CNTKBinaryFileHelper::ReadOrDie(&m_sampleDimension, sizeof(m_sampleDimension), 1, file);
}
struct DenseInputStreamBuffer : DenseSequenceData
{
// capacity = expected number of samples * sample size
const void* GetDataBuffer() override
{
return m_data;
}
void* m_data;
DataType m_dataType;
};
// In case of sparse input, we also need a vector of
// indices (one index for each input value) and a vector
// of NNZ counts (one for each sample).
struct SparseInputStreamBuffer : SparseSequenceData
{
SparseInputStreamBuffer()
@@ -60,16 +122,13 @@ protected:
return m_data;
}
std::vector<IndexType> m_indicesBuffer;
void* m_data;
};
protected:
StorageType m_storageType;
ElementType m_elemType;
size_t m_numCols;
ElementType m_precision;
DataType m_dataType;
uint32_t m_sampleDimension;
wstring m_name;
};
typedef shared_ptr<BinaryDataDeserialzer> BinaryDataDeserializerPtr;
@@ -77,156 +136,108 @@ typedef shared_ptr<BinaryDataDeserialzer> BinaryDataDeserializerPtr;
class DenseBinaryDataDeserializer : public BinaryDataDeserialzer
{
public:
DenseBinaryDataDeserializer(FILE* infile)
{
// We don't have to read the storage type. We know we're dense
m_storageType = StorageType::dense;
using BinaryDataDeserialzer::BinaryDataDeserialzer;
// Read the element type, note it's stored as an int32
int32_t elemType;
CNTKBinaryFileHelper::readOrDie(&elemType, sizeof(elemType), 1, infile);
if (elemType == 0)
m_elemType = ElementType::tfloat;
else if (elemType == 1)
m_elemType = ElementType::tdouble;
else
RuntimeError("Unsupported element type %d.", elemType);
// Read the number of columns
int32_t numCols;
CNTKBinaryFileHelper::readOrDie(&numCols, sizeof(numCols), 1, infile);
m_numCols = numCols;
}
virtual StorageType GetStorageType() override { return StorageType::dense; }
size_t GetSequenceDataForChunk(size_t numSequences, void* data, std::vector<SequenceDataPtr>& result)
{
size_t elemSize = GetElemSizeBytes();
size_t valueSize = SizeOfDataType();
result.resize(numSequences);
for (size_t c = 0; c < numSequences; c++)
size_t offset = 0;
for (size_t i = 0; i < numSequences; i++)
{
shared_ptr<DenseInputStreamBuffer> sequence = make_shared<DenseInputStreamBuffer>();
sequence->m_data = (char*)data + c*m_numCols*elemSize;
sequence->m_numberOfSamples = 1;
sequence->m_sampleLayout = std::make_shared<TensorShape>(m_numCols);
result[c] = sequence;
shared_ptr<DenseInputStreamBuffer> sequenceDataPtr = make_shared<DenseInputStreamBuffer>();
sequenceDataPtr->m_numberOfSamples = *(uint32_t*)((char*)data + offset);
offset += sizeof(uint32_t);
sequenceDataPtr->m_data = (char*)data + offset;
sequenceDataPtr->m_sampleLayout = GetSampleLayout();
sequenceDataPtr->m_elementType = m_precision;
result[i] = sequenceDataPtr;
offset += m_sampleDimension * valueSize * sequenceDataPtr->m_numberOfSamples;
}
// For dense, the number of bytes processed is just numRows * numCols * elemSize;
return numSequences * m_numCols * elemSize;
return offset;
}
};
class SparseBinaryDataDeserializer : public BinaryDataDeserialzer
{
public:
SparseBinaryDataDeserializer(FILE* infile)
SparseBinaryDataDeserializer(FILE* file, ElementType precision = ElementType::tfloat)
:BinaryDataDeserialzer(file, precision)
{
// Read the storage type. Currently we only support sparse_csc,
// but for future compatibility allow it to be a parameter.
int32_t storageType;
CNTKBinaryFileHelper::readOrDie(&storageType, sizeof(storageType), 1, infile);
if (storageType == 0)
m_storageType = StorageType::sparse_csc;
else
RuntimeError("Unsupported storage type %d.", storageType);
// Read the element type, note it's stored as an int32
int32_t elemType;
CNTKBinaryFileHelper::readOrDie(&elemType, sizeof(elemType), 1, infile);
if (elemType== 0)
m_elemType = ElementType::tfloat;
else if (elemType == 1)
m_elemType = ElementType::tdouble;
else
RuntimeError("Unsupported element type %d.", elemType);
int32_t isSequence;
CNTKBinaryFileHelper::readOrDie(&isSequence, sizeof(isSequence), 1, infile);
if (isSequence == 0)
m_isSequence = false;
else if (isSequence == 1)
m_isSequence = true;
else
RuntimeError("Unsupported sequence type %d.", isSequence);
// Read the number of columns
int32_t numCols;
CNTKBinaryFileHelper::readOrDie(&numCols, sizeof(numCols), 1, infile);
m_numCols = numCols;
if (IndexType(m_sampleDimension) < 0)
{
RuntimeError("Sample dimension is too large for an IndexType value.");
}
}
bool IsSequence() override { return m_isSequence; }
virtual StorageType GetStorageType() override { return StorageType::sparse_csc; }
// The format of data is:
// int32_t: nnz for the entire chunk
// ElemType[nnz]: the values for the sparse sequences
// int32_t[nnz]: the row offsets for the sparse sequences
// int32_t[numSequences]: the column offsets for the sparse sequences
// sequence[numSequences], where each sequence consists of:
// uint32_t: numSamples
// uint32_t: nnz for the sequence
// ElemType[nnz]: the values for the sparse sequences
// int32_t[nnz]: the row offsets for the sparse sequences
// int32_t[numSamples]: sizes (nnz counts) for each sample in the sequence
size_t GetSequenceDataForChunk(size_t numSequences, void* data, std::vector<SequenceDataPtr>& result)
{
size_t elemSize = GetElemSizeBytes();
size_t offset = 0;
result.resize(numSequences);
// For sparse, the first int32_t is the number of nnz values in the entire set of sequences
int32_t totalNNz = *(int32_t*)data;
// the rest of this chunk
// Since we're not templating on ElemType, we use void for the values. Note that this is the only place
// this deserializer uses ElemType, the rest are int32_t for this deserializer.
void* values = (char*)data + sizeof(int32_t);
// Now the row offsets
int32_t* rowOffsets = (int32_t*)((char*)values + elemSize * totalNNz);
// Now the col offsets
int32_t* colOffsets = rowOffsets + totalNNz;
// Now we setup some helper members to process the chunk
for (size_t colIndex = 0; colIndex < numSequences; colIndex++)
for (size_t i = 0; i < numSequences; i++)
{
shared_ptr<SparseInputStreamBuffer> sequence = make_shared<SparseInputStreamBuffer>();
// We can't populate sequence->m_chunk here, so delay that for later
// We know the number of elements in all of the samples, it's just this:
sequence->m_totalNnzCount = colOffsets[colIndex + 1] - colOffsets[colIndex];
// The values array is already properly packed, so just use it.
sequence->m_data = values;
// The indices are correct (note they MUST BE IN INCREASING ORDER), but we will have to fix them up a
// little bit, for now just use them
sequence->m_indices = rowOffsets;
for (int32_t curRow = 0; curRow < sequence->m_totalNnzCount; curRow++)
{
// Get the sample for the current index
size_t sampleNum = rowOffsets[curRow] / m_numCols;
// The current sample might be OOB, if so, fill in the the missing ones.
while(sequence->m_nnzCounts.size() < sampleNum+1)
sequence->m_nnzCounts.push_back(0);
// Now that we have enough samples, increment the nnz for the sample
sequence->m_nnzCounts[sampleNum] += 1;
// Now that we've found it's sample, fix up the index.
rowOffsets[curRow] %= m_numCols;
}
sequence->m_numberOfSamples = (uint32_t)sequence->m_nnzCounts.size();
// update values, rowOffsets pointers
values = (char*)values + sequence->m_totalNnzCount * elemSize;
rowOffsets += sequence->m_totalNnzCount;
result[colIndex] = sequence;
shared_ptr<SparseInputStreamBuffer> sequenceDataPtr = make_shared<SparseInputStreamBuffer>();
offset += GetSequenceData((char*)data + offset, sequenceDataPtr);
sequenceDataPtr->m_sampleLayout = GetSampleLayout();
sequenceDataPtr->m_elementType = m_precision;
result[i] = sequenceDataPtr;
}
// For sparse, we compute how many bytes we processed
// From the header to this function, we see that is:
// sizeof(int32_t) + totalNNz * sizeof(ElemType) + totalNNz * sizeof(int32_t) + numSequences * sizeof(int32_t)
return sizeof(int32_t) + totalNNz * (elemSize + sizeof(int32_t)) + (numSequences + 1) * sizeof(int32_t);
return offset;
}
private:
bool m_isSequence;
size_t GetSequenceData(void* data, shared_ptr<SparseInputStreamBuffer>& sequence)
{
size_t valueSize = SizeOfDataType();
size_t offset = 0;
// The very first value in the buffer is the number of samples in this sequence.
sequence->m_numberOfSamples = *(uint32_t*)data;
offset += sizeof(uint32_t);
// Next is the total number of elements in all of the samples.
uint32_t nnz = *(uint32_t*)((char*)data + offset);
if (IndexType(nnz) < 0)
{
RuntimeError("NNZ count is too large for an IndexType value.");
}
sequence->m_totalNnzCount = nnz;
offset += sizeof(uint32_t);
// the rest of this sequence
// Since we're not templating on ElemType, we use void for the values. Note that this is the only place
// this deserializer uses ElemType, the rest are int32_t for this deserializer.
// The data is already properly packed, so just use it.
sequence->m_data = (char*)data + offset;
offset += valueSize * sequence->m_totalNnzCount;
// The indices are supposed to be correctly packed (i.e., in increasing order)
sequence->m_indices = (int32_t*)((char*)data + offset);
offset += sizeof(int32_t) * sequence->m_totalNnzCount;
int32_t* begin = (int32_t*)((char*)data + offset);
offset += sizeof(int32_t) * sequence->m_numberOfSamples;
int32_t* end = (int32_t*)((char*)data + offset);
sequence->m_nnzCounts.reserve(sequence->m_numberOfSamples);
sequence->m_nnzCounts.assign(begin, end);
return offset;
}
};
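Mirroring the format comment above, a standalone Python sketch of decoding one sparse sequence record (float precision assumed; like GetSequenceData, it returns the advanced offset):
import struct
def parse_sparse_sequence(buf, offset=0):
    # uint32 sample count, then int32 total nnz for the sequence.
    num_samples, nnz = struct.unpack_from('Ii', buf, offset)
    offset += 8
    values = struct.unpack_from('%df' % nnz, buf, offset)
    offset += 4 * nnz
    indices = struct.unpack_from('%di' % nnz, buf, offset)
    offset += 4 * nnz
    sizes = struct.unpack_from('%di' % num_samples, buf, offset)
    offset += 4 * num_samples
    return (num_samples, values, indices, sizes), offset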

View file

@@ -23,8 +23,8 @@ CNTKBinaryReader::CNTKBinaryReader(const ConfigParameters& config)
{
BinaryConfigHelper configHelper(config);
string log;
log += "Initializing CNTKBinaryReader";
std::stringstream log;
log << "Initializing CNTKBinaryReader";
try
{
m_deserializer = shared_ptr<IDataDeserializer>(new BinaryChunkDeserializer(configHelper));
@@ -32,14 +32,16 @@ CNTKBinaryReader::CNTKBinaryReader(const ConfigParameters& config)
if (configHelper.ShouldKeepDataInMemory())
{
m_deserializer = shared_ptr<IDataDeserializer>(new ChunkCache(m_deserializer));
log += " | keeping data in memory";
log << " | keeping data in memory";
}
size_t window = configHelper.GetRandomizationWindow();
if (window > 0)
{
// Verbosity is a general config parameter, not specific to the binary format reader.
log += " | randomizing with window: " + (int)window;
log << " | randomizing with window: "
<< window
<< (configHelper.UseSampleBasedRandomizationWindow() ? " samples" : " chunks");
int verbosity = config(L"verbosity", 0);
m_sequenceEnumerator = make_shared<BlockRandomizer>(
verbosity, /* verbosity */
@@ -52,7 +54,7 @@ CNTKBinaryReader::CNTKBinaryReader(const ConfigParameters& config)
}
else
{
log += " | without randomization";
log << " | without randomization";
m_sequenceEnumerator = std::make_shared<NoRandomizer>(m_deserializer);
}
@@ -64,7 +66,7 @@ CNTKBinaryReader::CNTKBinaryReader(const ConfigParameters& config)
RuntimeError("CNTKBinaryReader: While reading '%ls': %s", configHelper.GetFilePath().c_str(), e.what());
}
if (configHelper.GetTraceLevel() > 2)
fprintf(stderr, "%s\n", log.c_str());
fprintf(stderr, "%s\n", log.str().c_str());
}
} } }

View file

@@ -34,6 +34,31 @@ namespace Microsoft { namespace MSR { namespace CNTK {
class CNTKBinaryFileHelper
{
public:
static const uint64_t MAGIC_NUMBER = 0x636e746b5f62696eU;
static void FindMagicOrDie(FILE* f, wstring name) {
// Read the magic number and make sure we're given a proper CBF file.
uint64_t number;
ReadOrDie(&number, sizeof(number), 1, f);
if (number != MAGIC_NUMBER)
RuntimeError("The input (%S) is not a valid CNTK binary format file.",
name.c_str());
}
static uint32_t GetVersionNumber(FILE* f) {
uint32_t versionNumber;
ReadOrDie(&versionNumber, sizeof(versionNumber), 1, f);
return versionNumber;
}
static int64_t GetHeaderOffset(FILE* f) {
// Seek to the end of file -8 bytes to find the offset of the header.
SeekOrDie(f, -int64_t(sizeof(int64_t)), SEEK_END);
int64_t headerOffset;
ReadOrDie(&headerOffset, sizeof(headerOffset), 1, f);
return headerOffset;
}
static FILE* openOrDie(const string& pathname, const char* mode)
{
FILE* f = fopen(pathname.c_str(), mode);
@@ -42,7 +67,7 @@ public:
return f;
}
static FILE* openOrDie(const wstring& pathname, const wchar_t* mode)
static FILE* OpenOrDie(const wstring& pathname, const wchar_t* mode)
{
FILE* f = _wfopen(pathname.c_str(), mode);
if (!f)
@@ -50,14 +75,14 @@ public:
return f;
}
static void closeOrDie(FILE* f)
static void CloseOrDie(FILE* f)
{
int rc = fclose(f);
if (rc != 0)
RuntimeError("Error closing: %s.", strerror(errno));
}
static void seekOrDie(FILE* f, int64_t offset, int mode)
static void SeekOrDie(FILE* f, int64_t offset, int mode)
{
int rc;
#ifdef __WINDOWS__
@@ -69,7 +94,7 @@ public:
RuntimeError("Error seeking: %s.", strerror(errno));
}
static int64_t tellOrDie(FILE* f)
static int64_t TellOrDie(FILE* f)
{
size_t rc;
#ifdef __WINDOWS__
@@ -82,7 +107,7 @@ public:
return rc;
}
static void readOrDie(void* ptr, size_t size, size_t count, FILE* f)
static void ReadOrDie(void* ptr, size_t size, size_t count, FILE* f)
{
size_t rc;
rc = fread(ptr, size, count, f);

View file

@@ -1431,16 +1431,22 @@ Test module "ReaderTests" has passed with:
95 test cases out of 95 passed
16971350 assertions out of 16971350 passed
Test case "ReaderTestSuite/CNTKBinaryReader_sparse_seq" has passed with:
1 assertion out of 1 passed
Test case "ReaderTestSuite/CNTKBinaryReader_Simple_sparse" has passed with:
1 assertion out of 1 passed
Test case "ReaderTestSuite/CNTKBinaryReader_Simple_dense" has passed with:
1 assertion out of 1 passed
Test case "ReaderTestSuite/CNTKBinaryReader_Simple_dense2" has passed with:
Test case "ReaderTestSuite/CNTKBinaryReader_MNIST_dense" has passed with:
1 assertion out of 1 passed
Test case "ReaderTestSuite/CNTKBinaryReader_10x10_dense" has passed with:
1 assertion out of 1 passed
Test case "ReaderTestSuite/CNTKBinaryReader_50x20_jagged_sequences_dense" has passed with:
1 assertion out of 1 passed
Test case "ReaderTestSuite/CNTKBinaryReader_10x10_sparse" has passed with:
1 assertion out of 1 passed
Test case "ReaderTestSuite/CNTKBinaryReader_50x20_jagged_sequences_sparse" has passed with:
1 assertion out of 1 passed
Test case "ReaderTestSuite/CNTKTextFormatReader_Simple_dense" has passed with:

View file

@@ -9,9 +9,7 @@
using namespace Microsoft::MSR::CNTK;
namespace Microsoft { namespace MSR { namespace CNTK {
namespace Test {
namespace Microsoft { namespace MSR { namespace CNTK { namespace Test {
struct CNTKBinaryReaderFixture : ReaderFixture
{
@@ -23,75 +21,116 @@ BOOST_FIXTURE_TEST_SUITE(ReaderTestSuite, CNTKBinaryReaderFixture)
BOOST_FIXTURE_TEST_SUITE(ReaderTestSuite, CNTKBinaryReaderFixture)
BOOST_AUTO_TEST_CASE(CNTKBinaryReader_sparse_seq)
{
HelperRunReaderTest<float>(
testDataPath() + "/Config/CNTKBinaryReader/test.cntk",
testDataPath() + "/Control/CNTKBinaryReader/Simple_sparse_seq.txt",
testDataPath() + "/Control/CNTKBinaryReader/Simple_sparse_seq_Output.txt",
"SparseSeq",
"reader",
1500, // epoch size
250, // mb size
1, // num epochs
2,
2,
0,
1, true, false, false);
};
BOOST_AUTO_TEST_CASE(CNTKBinaryReader_Simple_sparse)
{
HelperRunReaderTest<float>(
testDataPath() + "/Config/CNTKBinaryReader/test.cntk",
testDataPath() + "/Control/CNTKBinaryReader/Simple_sparse.txt",
testDataPath() + "/Control/CNTKBinaryReader/Simple_sparse_Output.txt",
"Sparse",
"reader",
1600, // epoch size
250, // mb size
1, // num epochs
2,
2,
0,
1, true, false, false);
};
BOOST_AUTO_TEST_CASE(CNTKBinaryReader_Simple_dense)
{
HelperRunReaderTest<float>(
testDataPath() + "/Config/CNTKBinaryReader/test.cntk",
testDataPath() + "/Control/CNTKBinaryReader/Simple_dense.txt",
testDataPath() + "/Control/CNTKTextFormatReader/Simple_dense.txt",
testDataPath() + "/Control/CNTKBinaryReader/Simple_dense_Output.txt",
"Simple",
"reader",
1600, // epoch size
1000, // epoch size
250, // mb size
1, // num epochs
4,
10, // num epochs
1,
1,
0,
0,
1, false, false, false);
1);
};
BOOST_AUTO_TEST_CASE(CNTKBinaryReader_Simple_dense2)
BOOST_AUTO_TEST_CASE(CNTKBinaryReader_MNIST_dense)
{
HelperRunReaderTest<double>(
testDataPath() + "/Config/CNTKBinaryReader/test.cntk",
testDataPath() + "/Control/CNTKTextFormatReader/MNIST_dense.txt",
testDataPath() + "/Control/CNTKBinaryReader/MNIST_dense_Output.txt",
"MNIST",
"reader",
1000, // epoch size
1000, // mb size
1, // num epochs
1,
1,
0,
1);
};
// 10 sequences with 10 samples each (no randomization)
BOOST_AUTO_TEST_CASE(CNTKBinaryReader_10x10_dense)
{
HelperRunReaderTest<float>(
testDataPath() + "/Config/CNTKBinaryReader/test.cntk",
testDataPath() + "/Control/CNTKBinaryReader/Simple_dense_312.txt",
testDataPath() + "/Control/CNTKBinaryReader/Simple_dense_312_Output.txt",
"Simple",
testDataPath() + "/Control/CNTKTextFormatReader/10x10_dense.txt",
testDataPath() + "/Control/CNTKBinaryReader/10x10_dense_Output.txt",
"10x10_dense",
"reader",
1600, // epoch size
312, // mb size
1, // num epochs
4,
100, // epoch size
100, // mb size
1, // num epochs
1,
0, // no labels
0,
0,
1, false, false, false);
1);
};
// 50 sequences with up to 20 samples each (508 samples in total)
BOOST_AUTO_TEST_CASE(CNTKBinaryReader_50x20_jagged_sequences_dense)
{
HelperRunReaderTest<double>(
testDataPath() + "/Config/CNTKBinaryReader/test.cntk",
testDataPath() + "/Control/CNTKTextFormatReader/50x20_jagged_sequences_dense.txt",
testDataPath() + "/Control/CNTKBinaryReader/50x20_jagged_sequences_dense_Output.txt",
"50x20_jagged_sequences_dense",
"reader",
508, // epoch size
508, // mb size
1, // num epochs
1,
0,
0,
1);
};
// 10 sequences with 10 samples each (no randomization)
BOOST_AUTO_TEST_CASE(CNTKBinaryReader_10x10_sparse)
{
HelperRunReaderTest<double>(
testDataPath() + "/Config/CNTKBinaryReader/test.cntk",
testDataPath() + "/Control/CNTKTextFormatReader/10x10_sparse.txt",
testDataPath() + "/Control/CNTKBinaryReader/10x10_sparse_Output.txt",
"10x10_sparse",
"reader",
100, // epoch size
100, // mb size
1, // num epochs
1,
0, // no labels
0,
1,
true);
};
// 50 sequences with up to 20 samples each (536 samples in total)
BOOST_AUTO_TEST_CASE(CNTKBinaryReader_50x20_jagged_sequences_sparse)
{
HelperRunReaderTest<float>(
testDataPath() + "/Config/CNTKBinaryReader/test.cntk",
testDataPath() + "/Control/CNTKTextFormatReader/50x20_jagged_sequences_sparse.txt",
testDataPath() + "/Control/CNTKBinaryReader/50x20_jagged_sequences_sparse_Output.txt",
"50x20_jagged_sequences_sparse",
"reader",
564, // epoch size
564, // mb size
1, // num epochs
1,
0,
0,
1,
true);
};
BOOST_AUTO_TEST_SUITE_END()
} } } }

View file

@@ -51,31 +51,6 @@ public:
namespace Test {
// identical to 'sort -o filename filename'
void SortLinesInFile(string filename, size_t expectedNumLines = 1)
{
vector<string> content;
content.reserve(expectedNumLines);
ifstream ifstream(filename);
string line;
while (getline(ifstream, line))
{
content.push_back(line);
}
ifstream.close();
sort(content.begin(), content.end());
ofstream ofstream(filename);
copy(content.begin(), content.end(), ostream_iterator<string>(ofstream, "\n"));
ofstream.close();
}
struct CNTKTextFormatReaderFixture : ReaderFixture
{
CNTKTextFormatReaderFixture()

View file

@@ -13,6 +13,30 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Test {
const double relError = 1e-5f;
// identical to 'sort -o filename filename'
inline void SortLinesInFile(string filename, size_t expectedNumLines = 1)
{
vector<string> content;
content.reserve(expectedNumLines);
ifstream ifstream(filename);
string line;
while (getline(ifstream, line))
{
content.push_back(line);
}
ifstream.close();
sort(content.begin(), content.end());
ofstream ofstream(filename);
copy(content.begin(), content.end(), ostream_iterator<string>(ofstream, "\n"));
ofstream.close();
}
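
Because randomization can permute sequence order, tests that compare a generated output file against a control file normalize both with this helper before diffing them. A hypothetical call, with an illustrative path not taken from this commit:

// Hypothetical usage inside a test fixture: sort the generated output in
// place before a line-wise comparison against an equally sorted control file.
SortLinesInFile(testDataPath() + "/Control/CNTKTextFormatReader/MyTest_Output.txt");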
struct ReaderFixture
{
// This fixture sets up paths so the tests can assume the right location for finding the configuration

View file

@@ -1,50 +1,91 @@
# deviceId = -1 for CPU, >= 0 for GPU devices
deviceId = -1
SparseSeq = [
precision = "float"
reader = [
readerType = "CNTKBinaryReader"
file = "sparseseqoutput.bin"
input = [
features1 = [ alias="a" ]
features2 = [ alias="b" ]
labels1 = [ alias="c" ]
labels2 = [ alias="d" ]
]
randomize = false
]
]
Sparse = [
precision = "float"
reader = [
readerType = "CNTKBinaryReader"
file = "sparseoutput.bin"
input = [
features1 = [ alias="a" ]
features2 = [ alias="b" ]
labels1 = [ alias="c" ]
labels2 = [ alias="d" ]
]
randomize = false
]
]
Simple = [
precision = "float"
reader = [
readerType = "CNTKBinaryReader"
file = "simple.bin"
file = "Simple_dense.bin"
randomize = false
]
]
MNIST = [
precision = "double"
reader = [
readerType = "CNTKBinaryReader"
file = "MNIST_dense.bin" # contains half a dozen chunks with ca. 400 KB in each
randomize = false
keepDataInMemory = true
]
]
10x10_dense = [
precision = "float"
reader = [
readerType = "CNTKBinaryReader"
# Training file contains ten sequences with ten samples each
file = "10x10_dense.bin"
randomize = false
]
]
50x20_jagged_sequences_dense = [
precision = "double"
reader = [
readerType = "CNTKBinaryReader"
# Training file contains 50 sequences with *up to* 20 samples each
file = "50x20_jagged_sequences_dense.bin"
randomize = false
]
]
10x10_sparse = [
precision = "double"
reader = [
readerType = "CNTKBinaryReader"
# Training file contains ten sequences with ten samples each
file = "10x10_sparse.bin"
randomize = false
]
]
50x20_jagged_sequences_sparse = [
precision = "float"
reader = [
readerType = "CNTKBinaryReader"
# Training file contains 50 sequences with *up to* 20 samples each
file = "50x20_jagged_sequences_sparse.bin"
randomize = false
]
]
100x100x3_randomize_auto = [
precision = "double"
reader = [
readerType = "CNTKBinaryReader"
# Training file contains 100 sequences with *up to* 100 samples
# in each of 3 inputs.
file = "100x100x3_jagged_sequences_dense.bin"
randomize = true
]
]
5_inputs_100x10_jagged_mixed = [
precision = "float"
reader = [
readerType = "CNTKBinaryReader"
# Training file contains 100 sequences with *up to* 10 samples
# in each of 5 inputs.
file = "5_inputs_100x10_jagged_mixed.bin"
input = [
features1 = [ alias="a" ]
features2 = [ alias="b" ]
features3 = [ alias="c" ]
features4 = [ alias="d" ]
features5 = [ alias="e" ]
]
randomize = false
]
]
]
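
Taken together, the sections above exercise every CNTKBinaryReader option this commit touches. As a sketch only, with illustrative stream and file names, one section combining them might look like:

ExampleBinary = [
    precision = "float"
    reader = [
        readerType = "CNTKBinaryReader"
        file = "example.bin"
        # Optional: map the stream aliases stored in the binary file
        # to the input names CNTK should expose.
        input = [
            features = [ alias="a" ]
            labels = [ alias="b" ]
        ]
        randomize = false
        # Presumably caches all chunks once read (an assumption based on
        # the option name and its use in the MNIST section above).
        keepDataInMemory = true
    ]
]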

Diff not shown because the file is too large.

Diff not shown because the file is too large.

Diff not shown because the file is too large.

Binary file not shown.

Binary data
Tests/UnitTests/ReaderTests/Data/CNTKBinaryReader/10x10_dense.bin (new file)

Binary file not shown.

Binary data
Tests/UnitTests/ReaderTests/Data/CNTKBinaryReader/10x10_sparse.bin (new file)

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary data
Tests/UnitTests/ReaderTests/Data/CNTKBinaryReader/MNIST_dense.bin (new file)

Binary file not shown.

Binary data
Tests/UnitTests/ReaderTests/Data/CNTKBinaryReader/Simple_dense.bin (new file)

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View file

@@ -143,6 +143,10 @@
<Text Include="Config\HTKMLFReaderSimpleDataLoop9_Config.cntk" />
<Text Include="Config\ImageReaderSimple_Config.cntk" />
<Text Include="Config\UCIFastReaderSimpleDataLoop_Config.cntk" />
<Text Include="Control\CNTKBinaryReader\Simple_dense.txt" />
<Text Include="Control\CNTKBinaryReader\Simple_dense_312.txt" />
<Text Include="Control\CNTKBinaryReader\Simple_sparse.txt" />
<Text Include="Control\CNTKBinaryReader\Simple_sparse_seq.txt" />
<Text Include="Control\CNTKTextFormatReader\100x100x3_jagged_sequences_dense_sorted.txt" />
<Text Include="Control\CNTKTextFormatReader\100x1_1_dense.txt" />
<Text Include="Control\CNTKTextFormatReader\100x1_2_dense.txt" />
@@ -218,6 +222,7 @@
<Image Include="Data\images\red.jpg" />
</ItemGroup>
<ItemGroup>
<None Include="Config\CNTKBinaryReader\test.cntk" />
<None Include="Config\CNTKTextFormatReader\dense.cntk" />
<None Include="Config\CNTKTextFormatReader\edge_cases.cntk" />
<None Include="Config\CNTKTextFormatReader\sparse.cntk" />
@@ -244,6 +249,9 @@
<None Include="Config\ImageReaderLabelOutOfRange_Config.cntk" />
<None Include="Config\ImageReaderMultiView_Config.cntk" />
<None Include="Config\ImageReaderZip_Config.cntk" />
<None Include="Data\CNTKBinaryReader\simple.bin" />
<None Include="Data\CNTKBinaryReader\sparseoutput.bin" />
<None Include="Data\CNTKBinaryReader\sparseseqoutput.bin" />
<None Include="Data\images\chunk0.zip" />
<None Include="Data\images\chunk1.zip" />
<None Include="Data\images\simple.zip" />

View file

@@ -52,6 +52,15 @@
<Filter Include="Config\HTKDeserializers">
<UniqueIdentifier>{b2a5f515-1b00-4f19-bd19-8e3fd7eb6872}</UniqueIdentifier>
</Filter>
<Filter Include="Config\CNTKBinaryReader">
<UniqueIdentifier>{46df705e-a35d-4bb8-9272-92c7cb293079}</UniqueIdentifier>
</Filter>
<Filter Include="Control\CNTKBinaryReader">
<UniqueIdentifier>{4194f906-450f-4112-841d-c80b3f90081e}</UniqueIdentifier>
</Filter>
<Filter Include="Data\CNTKBinaryReader">
<UniqueIdentifier>{4637ff35-9ba4-4850-bfe5-d534c5920388}</UniqueIdentifier>
</Filter>
</ItemGroup>
<ItemGroup>
<Text Include="Data\ImageReaderSimple_map.txt">
@@ -324,6 +333,18 @@
<Text Include="Data\ImageAndTextReaderSimple_map.txt">
<Filter>Data</Filter>
</Text>
<Text Include="Control\CNTKBinaryReader\Simple_dense.txt">
<Filter>Control\CNTKBinaryReader</Filter>
</Text>
<Text Include="Control\CNTKBinaryReader\Simple_dense_312.txt">
<Filter>Control\CNTKBinaryReader</Filter>
</Text>
<Text Include="Control\CNTKBinaryReader\Simple_sparse.txt">
<Filter>Control\CNTKBinaryReader</Filter>
</Text>
<Text Include="Control\CNTKBinaryReader\Simple_sparse_seq.txt">
<Filter>Control\CNTKBinaryReader</Filter>
</Text>
</ItemGroup>
<ItemGroup>
<Image Include="Data\images\black.jpg">
@@ -430,6 +451,18 @@
<None Include="Config\ImageDeserializers.cntk">
<Filter>Config</Filter>
</None>
<None Include="Config\CNTKBinaryReader\test.cntk">
<Filter>Config\CNTKBinaryReader</Filter>
</None>
<None Include="Data\CNTKBinaryReader\simple.bin">
<Filter>Data\CNTKBinaryReader</Filter>
</None>
<None Include="Data\CNTKBinaryReader\sparseoutput.bin">
<Filter>Data\CNTKBinaryReader</Filter>
</None>
<None Include="Data\CNTKBinaryReader\sparseseqoutput.bin">
<Filter>Data\CNTKBinaryReader</Filter>
</None>
</ItemGroup>
<ItemGroup>
<Xml Include="Data\ImageNet1K_intensity.xml">