removed template parameter ElemType from (I)DataReader and (I)DataWriter

This commit is contained in:
Frank Seide 2016-02-28 19:01:07 -08:00
Родитель 5bdc86e5de
Коммит e54b352822
60 изменённых файлов: 390 добавлений и 609 удалений

Просмотреть файл

@ -298,12 +298,12 @@ status open
\begin_layout Plain Layout
extern "C" DATAREADER_API void GetReaderF(IDataReader<float>** preader);
extern "C" DATAREADER_API void GetReaderF(IDataReader** preader);
\end_layout
\begin_layout Plain Layout
extern "C" DATAREADER_API void GetReaderD(IDataReader<double>** preader);
extern "C" DATAREADER_API void GetReaderD(IDataReader** preader);
\end_layout
\end_inset
@ -319,12 +319,12 @@ status open
\begin_layout Plain Layout
extern "C" DATAWRITER_API void GetWriterF(IDataWriter<float>** pwriter);
extern "C" DATAWRITER_API void GetWriterF(IDataWriter** pwriter);
\end_layout
\begin_layout Plain Layout
extern "C" DATAWRITER_API void GetWriterD(IDataWriter<double>** pwriter);
extern "C" DATAWRITER_API void GetWriterD(IDataWriter** pwriter);
\end_layout
\end_inset
@ -440,14 +440,13 @@ ochSamples=requestDataSize) = 0;
\begin_layout Plain Layout
virtual const std::map<typename LabelIdType, typename LabelType>& GetLabelMapp
virtual const std::map<LabelIdType, LabelType>& GetLabelMapp
ing(const std::wstring& sectionName) = 0;
\end_layout
\begin_layout Plain Layout
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<t
ypename LabelIdType, typename LabelType>& labelMapping) = 0;
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, LabelType>& labelMapping) = 0;
\end_layout
\begin_layout Plain Layout
@ -1097,8 +1096,7 @@ public:
\begin_layout Plain Layout
virtual void SaveMapping(std::wstring saveId, const std::map<typename
LabelIdType, typename LabelType>& labelMapping) = 0;
virtual void SaveMapping(std::wstring saveId, const std::map<LabelIdType, LabelType>& labelMapping) = 0;
\end_layout
\begin_layout Plain Layout

Просмотреть файл

@ -321,7 +321,7 @@ test = [
action = "eval"
# correspond to the number of words/characteres to train in a minibatch
minibatchSize = 8192 # choose as large as memory allows for maximum concurrency
minibatchSize = 8192 # choose as large as memory allows for maximum GPU concurrency
# need to be small since models are updated for each minibatch
traceLevel = 1
epochSize = 0

Просмотреть файл

@ -307,8 +307,8 @@ $(BINARY_READER): $(BINARYREADER_OBJ) | $(CNTKMATH_LIB)
########################################
HTKMLFREADER_SRC =\
$(SOURCEDIR)/Readers/HTKMLFReader/DataReader.cpp \
$(SOURCEDIR)/Readers/HTKMLFReader/DataWriter.cpp \
$(SOURCEDIR)/Readers/HTKMLFReader/Exports.cpp \
$(SOURCEDIR)/Readers/HTKMLFReader/DataWriterLocal.cpp \
$(SOURCEDIR)/Readers/HTKMLFReader/HTKMLFReader.cpp \
$(SOURCEDIR)/Readers/HTKMLFReader/HTKMLFWriter.cpp \
@ -348,6 +348,7 @@ $(LMSEQUENCEREADER): $(LMSEQUENCEREADER_OBJ) | $(CNTKMATH_LIB)
LUSEQUENCEREADER_SRC =\
$(SOURCEDIR)/Readers/LUSequenceReader/Exports.cpp \
$(SOURCEDIR)/Readers/LUSequenceReader/DataWriterLocal.cpp \
$(SOURCEDIR)/Readers/LUSequenceReader/LUSequenceParser.cpp \
$(SOURCEDIR)/Readers/LUSequenceReader/LUSequenceReader.cpp \

Просмотреть файл

@ -43,7 +43,7 @@ using namespace Microsoft::MSR::CNTK;
// ===========================================================================
template <typename ElemType>
static void DoEvalBase(const ConfigParameters& config, IDataReader<ElemType>& reader)
static void DoEvalBase(const ConfigParameters& config, IDataReader& reader)
{
DEVICEID_TYPE deviceId = DeviceFromConfig(config);
ConfigArray minibatchSize = config(L"minibatchSize", "40960");
@ -78,9 +78,9 @@ void DoEval(const ConfigParameters& config)
ConfigParameters readerConfig(config(L"reader"));
readerConfig.Insert("traceLevel", config(L"traceLevel", "0"));
DataReader<ElemType> testDataReader(readerConfig);
DataReader testDataReader(readerConfig);
DoEvalBase(config, testDataReader);
DoEvalBase<ElemType>(config, testDataReader);
}
template void DoEval<double>(const ConfigParameters& config);
@ -125,7 +125,7 @@ void DoCrossValidate(const ConfigParameters& config)
std::vector<std::vector<double>> cvErrorResults;
std::vector<std::wstring> cvModels;
DataReader<ElemType> cvDataReader(readerConfig);
DataReader cvDataReader(readerConfig);
bool finalModelEvaluated = false;
for (size_t i = cvInterval[0]; i <= cvInterval[2]; i += cvInterval[1])
@ -206,7 +206,7 @@ void DoWriteOutput(const ConfigParameters& config)
readerConfig.Insert("traceLevel", config(L"traceLevel", "0"));
readerConfig.Insert("randomize", "None"); // we don't want randomization when output results
DataReader<ElemType> testDataReader(readerConfig);
DataReader testDataReader(readerConfig);
DEVICEID_TYPE deviceId = DeviceFromConfig(config);
ConfigArray minibatchSize = config(L"minibatchSize", "2048");
@ -244,7 +244,7 @@ void DoWriteOutput(const ConfigParameters& config)
{
ConfigParameters writerConfig(config(L"writer"));
bool bWriterUnittest = writerConfig(L"unittest", "false");
DataWriter<ElemType> testDataWriter(writerConfig);
DataWriter testDataWriter(writerConfig);
writer.WriteOutput(testDataReader, mbSize[0], testDataWriter, outputNodeNamesVector, epochSize, bWriterUnittest);
}
else if (config.Exists("outputPath"))

Просмотреть файл

@ -85,8 +85,7 @@ void DoCreateLabelMap(const ConfigParameters& config)
}
fprintf(stderr, "CreateLabelMap: Creating the mapping file '%s' \n", labelMappingFile.c_str());
DataReader<ElemType> dataReader(readerConfig);
DataReader dataReader(readerConfig);
dataReader.StartMinibatchLoop(minibatchSize, 0, requestDataSize);
int count = 0;
while (dataReader.GetMinibatch(matrices))

Просмотреть файл

@ -155,11 +155,11 @@ void DoTrain(const ConfigRecordType& config)
RuntimeError("No network builder found in the config file. NDLNetworkBuilder or SimpleNetworkBuilde must be specified");
}
auto dataReader = CreateObject<DataReader<ElemType>>(config, L"reader");
auto dataReader = CreateObject<DataReader>(config, L"reader");
shared_ptr<DataReader<ElemType>> cvDataReader;
shared_ptr<DataReader> cvDataReader;
if (config.Exists(L"cvReader"))
cvDataReader = CreateObject<DataReader<ElemType>>(config, L"cvReader");
cvDataReader = CreateObject<DataReader>(config, L"cvReader");
shared_ptr<SGD<ElemType>> optimizer;
if (config.Exists(L"optimizer"))
@ -225,15 +225,15 @@ void DoAdapt(const ConfigParameters& config)
ConfigParameters readerConfig(config(L"reader"));
readerConfig.Insert("traceLevel", config(L"traceLevel", "0"));
DataReader<ElemType>* dataReader = new DataReader<ElemType>(readerConfig);
auto dataReader = make_shared<DataReader>(readerConfig);
DataReader<ElemType>* cvDataReader = nullptr;
shared_ptr<DataReader> cvDataReader;
ConfigParameters cvReaderConfig(config(L"cvReader", L""));
if (cvReaderConfig.size() != 0)
{
cvReaderConfig.Insert("traceLevel", config(L"traceLevel", "0"));
cvDataReader = new DataReader<ElemType>(cvReaderConfig);
cvDataReader = make_shared<DataReader>(cvReaderConfig);
}
wstring origModelFileName = config(L"origModelFileName", L"");
@ -241,10 +241,7 @@ void DoAdapt(const ConfigParameters& config)
SGD<ElemType> sgd(configSGD);
sgd.Adapt(origModelFileName, refNodeName, dataReader, cvDataReader, deviceId, makeMode);
delete dataReader;
delete cvDataReader;
sgd.Adapt(origModelFileName, refNodeName, dataReader.get(), cvDataReader.get(), deviceId, makeMode);
}
template void DoAdapt<float>(const ConfigParameters& config);

Просмотреть файл

@ -612,25 +612,17 @@ int wmainOldCNTKConfig(int argc, wchar_t* argv[]) // called from wmain which is
std::string type = config(L"precision", "float");
// accept old precision key for backward compatibility
if (config.Exists("type"))
{
type = config(L"type", "float");
}
InvalidArgument("CNTK: Use of 'type' parameter is deprecated, it is called 'precision' now.");
fprintf(stderr, "\nPrecision = \"%s\"\n", type.c_str());
if (type == "float")
{
DoCommands<float>(config);
}
else if (type == "double")
{
DoCommands<double>(config);
}
else
{
RuntimeError("CNTK: Invalid precision string: \"%s\", must be \"float\" or \"double\"", type.c_str());
}
// still here , write a DoneFile if necessary
// if completed then write a DoneFile if requested
if (!DoneFile.empty())
{
FILE* fp = fopenOrDie(DoneFile.c_str(), L"w");

Просмотреть файл

@ -107,7 +107,7 @@ void TestReader(const ConfigParameters& configBase)
epochSize = requestDataSize;
}
DataReader<ElemType> dataReader(readerConfig);
DataReader dataReader(readerConfig);
// get names of features and labels
std::vector<std::wstring> featureNames;
@ -171,7 +171,7 @@ void TestSequenceReader(const ConfigParameters& configBase)
std::vector<std::wstring> labelNames;
GetFileConfigNames(readerConfig, featureNames, labelNames);
DataReader<ElemType> dataReader(readerConfig);
DataReader dataReader(readerConfig);
// get names of features and labels
std::vector<std::wstring> files;

Просмотреть файл

@ -15,32 +15,24 @@
#include "DataReader.h"
#include "Config.h"
#include "ScriptableObjects.h"
#include <string>
using namespace std;
namespace Microsoft { namespace MSR { namespace CNTK {
template <class ElemType>
std::string GetReaderName(ElemType)
static const char* GetReaderName(const string& precision)
{
return std::string();
}
template <>
std::string GetReaderName(float)
{
std::string name = "GetReaderF";
return name;
}
template <>
std::string GetReaderName(double)
{
std::string name = "GetReaderD";
return name;
if (precision == "float")
return "GetReaderF";
else if (precision == "double")
return "GetReaderD";
else
InvalidArgument("DataReader: The 'precision' parameter must be 'float' or 'double'.");
}
template <class ElemType>
template <class ConfigRecordType>
void DataReader<ElemType>::InitFromConfig(const ConfigRecordType& /*config*/)
void DataReader::InitFromConfig(const ConfigRecordType& /*config*/)
{
RuntimeError("Init shouldn't be called, use constructor");
// not implemented, calls the underlying class instead
@ -48,8 +40,7 @@ void DataReader<ElemType>::InitFromConfig(const ConfigRecordType& /*config*/)
// Destroy - cleanup and remove this class
// NOTE: this destroys the object, and it can't be used past this point
template <class ElemType>
void DataReader<ElemType>::Destroy()
void DataReader::Destroy()
{
// newer code that explicitly place multiple streams for inputs
foreach_index (i, m_ioNames) // inputNames should map to node names
@ -60,14 +51,15 @@ void DataReader<ElemType>::Destroy()
// DataReader Constructor
// options - [in] string of options (i.e. "-windowsize:11 -addenergy") data reader specific
template <class ElemType>
template <class ConfigRecordType>
DataReader<ElemType>::DataReader(const ConfigRecordType& config)
DataReader::DataReader(const ConfigRecordType& config)
{
typedef void (*GetReaderProc)(IDataReader<ElemType>** preader);
typedef void (*GetReaderProc)(IDataReader** preader);
assert(m_dataReaders.empty());
string precision = config(L"precision", "float");
bool hasMultipleReaders = config.Exists(L"readers");
if (hasMultipleReaders)
{
@ -77,7 +69,7 @@ DataReader<ElemType>::DataReader(const ConfigRecordType& config)
{
const ConfigRecordType& thisIO = config(ioName);
// get the name for the reader we want to use, default to UCIFastReader
GetReaderProc getReaderProc = (GetReaderProc) Plugin::Load(thisIO(L"readerType", L"UCIFastReader"), GetReaderName((ElemType) 0));
GetReaderProc getReaderProc = (GetReaderProc) Plugin::Load(thisIO(L"readerType", L"UCIFastReader"), GetReaderName(precision));
m_ioNames.push_back(ioName);
assert(getReaderProc != nullptr);
getReaderProc(&m_dataReaders[ioName]); // instantiates the reader with the default constructor (no config processed at this point)
@ -88,7 +80,7 @@ DataReader<ElemType>::DataReader(const ConfigRecordType& config)
wstring ioName = L"ioName";
// backward support to use only one type of data reader
// get the name for the reader we want to use, default to UCIFastReader
GetReaderProc getReaderProc = (GetReaderProc) Plugin::Load(config(L"readerType", L"UCIFastReader"), GetReaderName((ElemType) 0));
GetReaderProc getReaderProc = (GetReaderProc)Plugin::Load(config(L"readerType", L"UCIFastReader"), GetReaderName(precision));
m_ioNames.push_back(ioName);
assert(getReaderProc != nullptr);
getReaderProc(&m_dataReaders[ioName]);
@ -109,14 +101,11 @@ DataReader<ElemType>::DataReader(const ConfigRecordType& config)
}
}
template DataReader<float>::DataReader(const ConfigParameters&);
template DataReader<double>::DataReader(const ConfigParameters&);
template DataReader<float>::DataReader(const ScriptableObjects::IConfigRecord&);
template DataReader<double>::DataReader(const ScriptableObjects::IConfigRecord&);
template DataReader::DataReader(const ConfigParameters&);
template DataReader::DataReader(const ScriptableObjects::IConfigRecord&);
// destructor - cleanup temp files, etc.
template <class ElemType>
DataReader<ElemType>::~DataReader()
DataReader::~DataReader()
{
// free up resources
for (size_t i = 0; i < m_ioNames.size(); i++)
@ -127,16 +116,14 @@ DataReader<ElemType>::~DataReader()
// mbSize - [in] size of the minibatch (number of frames, etc.)
// epoch - [in] epoch number for this loop
// requestedEpochSamples - [in] number of samples to randomize, defaults to requestDataSize which uses the number of samples there are in the dataset
template <class ElemType>
void DataReader<ElemType>::StartMinibatchLoop(size_t mbSize, size_t epoch, size_t requestedEpochSamples)
void DataReader::StartMinibatchLoop(size_t mbSize, size_t epoch, size_t requestedEpochSamples)
{
for (size_t i = 0; i < m_ioNames.size(); i++)
m_dataReaders[m_ioNames[i]]->StartMinibatchLoop(mbSize, epoch, requestedEpochSamples);
}
//SupportsDistributedMBRead - Tells if the reader supports distributed minibatch reading for parallel training
template <class ElemType>
bool DataReader<ElemType>::SupportsDistributedMBRead() const
bool DataReader::SupportsDistributedMBRead() const
{
bool supportsDistributedMBRead = true;
for (size_t i = 0; i < m_ioNames.size(); i++)
@ -156,8 +143,7 @@ bool DataReader<ElemType>::SupportsDistributedMBRead() const
// subsetNum - [in] the subset number of the current node in a group of parallel training nodes
// numSubsets - [in] total number of nodes participating in the parallel training
// requestedEpochSamples - [in] number of samples to randomize, defaults to requestDataSize which uses the number of samples there are in the dataset
template <class ElemType>
void DataReader<ElemType>::StartDistributedMinibatchLoop(size_t mbSize, size_t epoch, size_t subsetNum, size_t numSubsets, size_t requestedEpochSamples /* = requestDataSize*/)
void DataReader::StartDistributedMinibatchLoop(size_t mbSize, size_t epoch, size_t subsetNum, size_t numSubsets, size_t requestedEpochSamples /* = requestDataSize*/)
{
for (size_t i = 0; i < m_ioNames.size(); i++)
{
@ -169,8 +155,7 @@ void DataReader<ElemType>::StartDistributedMinibatchLoop(size_t mbSize, size_t e
// matrices - [in] a map with named matrix types (i.e. 'features', 'labels') mapped to the corresponding matrix,
// [out] each matrix resized if necessary containing data.
// returns - true if there are more minibatches, false if no more minibatchs remain
template <class ElemType>
bool DataReader<ElemType>::GetMinibatch(StreamMinibatchInputs& matrices)
bool DataReader::GetMinibatch(StreamMinibatchInputs& matrices)
{
/**
each reader reads data with number of columns as nbr_utterances_per_minibatch * mbSize
@ -193,7 +178,7 @@ bool DataReader<ElemType>::GetMinibatch(StreamMinibatchInputs& matrices)
if (nbr == 0)
nbr = thisNbr;
else if (thisNbr != nbr)
LogicError("DataReader<ElemType>::GetMinibatch: The specified number of utterances per minibatch is not consistent to the actual number of utterances per minibatch");
LogicError("DataReader::GetMinibatch: The specified number of utterances per minibatch is not consistent to the actual number of utterances per minibatch");
}
return bRet;
}
@ -203,38 +188,31 @@ bool DataReader<ElemType>::GetMinibatch(StreamMinibatchInputs& matrices)
// uids - lables stored in size_t vector instead of ElemType matrix
// boundary - phone boundaries
// returns - true if there are more minibatches, false if no more minibatchs remain
template <class ElemType>
bool DataReader<ElemType>::GetMinibatch4SE(std::vector<shared_ptr<const msra::dbn::latticepair>>& latticeinput, vector<size_t>& uids, vector<size_t>& boundaries, vector<size_t>& extrauttmap)
bool DataReader::GetMinibatch4SE(std::vector<shared_ptr<const msra::dbn::latticepair>>& latticeinput, vector<size_t>& uids, vector<size_t>& boundaries, vector<size_t>& extrauttmap)
{
bool bRet = true;
for (size_t i = 0; i < m_ioNames.size(); i++)
{
bRet &= m_dataReaders[m_ioNames[i]]->GetMinibatch4SE(latticeinput, uids, boundaries, extrauttmap);
}
return bRet;
}
// GetHmmData - Get the HMM definition for SE training
// hmm - HMM definition
// returns - true if succeed
template <class ElemType>
bool DataReader<ElemType>::GetHmmData(msra::asr::simplesenonehmm* hmm)
bool DataReader::GetHmmData(msra::asr::simplesenonehmm* hmm)
{
bool bRet = true;
for (size_t i = 0; i < m_ioNames.size(); i++)
{
bRet &= m_dataReaders[m_ioNames[i]]->GetHmmData(hmm);
}
return bRet;
}
template <class ElemType>
size_t DataReader<ElemType>::GetNumParallelSequences()
size_t DataReader::GetNumParallelSequences()
{
size_t nNbr = 0;
for (size_t i = 0; i < m_ioNames.size(); i++)
{
IDataReader<ElemType>* ptr = m_dataReaders[m_ioNames[i]];
IDataReader* ptr = m_dataReaders[m_ioNames[i]];
if (nNbr == 0)
nNbr = ptr->GetNumParallelSequences();
else if (nNbr != ptr->GetNumParallelSequences())
@ -243,37 +221,13 @@ size_t DataReader<ElemType>::GetNumParallelSequences()
return nNbr;
}
#if 0
template <class ElemType>
bool DataReader<ElemType>::RequireSentenceSeg() const
{
bool ans = false;
for (size_t i = 0; i < m_ioNames.size(); i++)
ans = ans || m_dataReaders.find(m_ioNames[i])->second->RequireSentenceSeg(); // can't say m_dataReaders[] since that is non-const...
return ans;
}
#endif
template <class ElemType>
void DataReader<ElemType>::InitProposals(StreamMinibatchInputs* matrices)
void DataReader::InitProposals(StreamMinibatchInputs* matrices)
{
for (size_t i = 0; i < m_ioNames.size(); i++)
m_dataReaders[m_ioNames[i]]->InitProposals(matrices);
}
#if 0
template <class ElemType>
int DataReader<ElemType>::GetSentenceEndIdFromOutputLabel()
{
int iRet = -1;
for (size_t i = 0; i < m_ioNames.size(); i++)
iRet = m_dataReaders[m_ioNames[i]]->GetSentenceEndIdFromOutputLabel();
return iRet;
}
#endif
template <class ElemType>
bool DataReader<ElemType>::GetProposalObs(StreamMinibatchInputs* matrices, const size_t tidx, vector<size_t>& history)
bool DataReader::GetProposalObs(StreamMinibatchInputs* matrices, const size_t tidx, vector<size_t>& history)
{
bool bRet = true;
for (size_t i = 0; i < m_ioNames.size(); i++)
@ -281,23 +235,20 @@ bool DataReader<ElemType>::GetProposalObs(StreamMinibatchInputs* matrices, const
return bRet;
}
template <class ElemType>
void DataReader<ElemType>::CopyMBLayoutTo(MBLayoutPtr pMBLayout)
void DataReader::CopyMBLayoutTo(MBLayoutPtr pMBLayout)
{
// BUGBUG: This copies all data reader's layout info on top of each other, keeping only the last one; likely not what was intended.
for (size_t i = 0; i < m_ioNames.size(); i++)
m_dataReaders[m_ioNames[i]]->CopyMBLayoutTo(pMBLayout);
}
template <class ElemType>
void DataReader<ElemType>::SetRandomSeed(int seed)
void DataReader::SetRandomSeed(int seed)
{
for (size_t i = 0; i < m_ioNames.size(); i++)
m_dataReaders[m_ioNames[i]]->SetRandomSeed(seed);
}
template <class ElemType>
bool DataReader<ElemType>::GetMinibatchCopy(
bool DataReader::GetMinibatchCopy(
std::vector<std::vector<std::pair<wstring, size_t>>>& uttInfo,
StreamMinibatchInputs& matrices,
MBLayoutPtr pMBLayout)
@ -308,8 +259,7 @@ bool DataReader<ElemType>::GetMinibatchCopy(
return ans;
}
template <class ElemType>
bool DataReader<ElemType>::SetNetOutput(
bool DataReader::SetNetOutput(
const std::vector<std::vector<std::pair<wstring, size_t>>>& uttInfo,
const MatrixBase& outputs,
const MBLayoutPtr pMBLayout)
@ -322,8 +272,7 @@ bool DataReader<ElemType>::SetNetOutput(
// GetLabelMapping - Gets the label mapping from integer index to label type
// returns - a map from numeric datatype to native label type
template <class ElemType>
const std::map<typename DataReader<ElemType>::LabelIdType, typename DataReader<ElemType>::LabelType>& DataReader<ElemType>::GetLabelMapping(const std::wstring&)
const std::map<typename DataReader::LabelIdType, typename DataReader::LabelType>& DataReader::GetLabelMapping(const std::wstring&)
{
NOT_IMPLEMENTED;
}
@ -331,8 +280,7 @@ const std::map<typename DataReader<ElemType>::LabelIdType, typename DataReader<E
// SetLabelMapping - Sets the label mapping from integer index to label
// labelMapping - mapping table from label values to IDs (must be 0-n)
// note: for tasks with labels, the mapping table must be the same between a training run and a testing run
template <class ElemType>
void DataReader<ElemType>::SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, LabelType>& labelMapping)
void DataReader::SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, LabelType>& labelMapping)
{
for (size_t i = 0; i < m_ioNames.size(); i++)
m_dataReaders[m_ioNames[i]]->SetLabelMapping(sectionName, labelMapping);
@ -346,8 +294,7 @@ void DataReader<ElemType>::SetLabelMapping(const std::wstring& sectionName, cons
// [out] size of buffer filled with data
// recordStart - record to start reading from, defaults to zero (start of data)
// returns: true if data remains to be read, false if the end of data was reached
template <class ElemType>
bool DataReader<ElemType>::GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart)
bool DataReader::GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart)
{
bool bRet = true;
for (size_t i = 0; i < m_ioNames.size(); i++)
@ -355,8 +302,7 @@ bool DataReader<ElemType>::GetData(const std::wstring& sectionName, size_t numRe
return bRet;
}
template <class ElemType>
bool DataReader<ElemType>::DataEnd()
bool DataReader::DataEnd()
{
bool bRet = true;
for (size_t i = 0; i < m_ioNames.size(); i++)
@ -364,10 +310,7 @@ bool DataReader<ElemType>::DataEnd()
return bRet;
}
//The explicit instantiation
template class DataReader<double>;
template class DataReader<float>;
// register SGD<> with the ScriptableObject system
ScriptableObjects::ConfigurableRuntimeTypeRegister::AddFloatDouble<DataReader<float>, DataReader<double>> registerDataReaderPlugin(L"DataReaderPlugin");
} } }
ScriptableObjects::ConfigurableRuntimeTypeRegister::Add<DataReader> registerDataReaderPlugin(L"DataReaderPlugin");
}}}

Просмотреть файл

@ -11,29 +11,18 @@
namespace Microsoft { namespace MSR { namespace CNTK {
template <class ElemType>
std::string GetWriterName(ElemType)
static const char* GetWriterName(const string& precision)
{
std::string empty;
return empty;
if (precision == "float")
return "GetWriterF";
else if (precision == "double")
return "GetWriterD";
else
InvalidArgument("DataWriter: The 'precision' parameter must be 'float' or 'double'.");
}
template <>
std::string GetWriterName(float)
{
std::string name = "GetWriterF";
return name;
}
template <>
std::string GetWriterName(double)
{
std::string name = "GetWriterD";
return name;
}
template <class ElemType>
template <class ConfigRecordType>
void DataWriter<ElemType>::InitFromConfig(const ConfigRecordType& /*config*/)
void DataWriter::InitFromConfig(const ConfigRecordType& /*config*/)
{
RuntimeError("Init shouldn't be called, use constructor");
// not implemented, calls the underlying class instead
@ -41,22 +30,17 @@ void DataWriter<ElemType>::InitFromConfig(const ConfigRecordType& /*config*/)
// Destroy - cleanup and remove this class
// NOTE: this destroys the object, and it can't be used past this point
template <class ElemType>
void DataWriter<ElemType>::Destroy()
void DataWriter::Destroy()
{
m_dataWriter->Destroy();
}
// DataWriter Constructor
// config - [in] configuration data for the data writer
template <class ElemType>
template <class ConfigRecordType>
DataWriter<ElemType>::DataWriter(const ConfigRecordType& config)
DataWriter::DataWriter(const ConfigRecordType& config)
{
typedef void (*GetWriterProc)(IDataWriter<ElemType>** pwriter);
// initialize just in case
m_dataWriter = NULL;
typedef void (*GetWriterProc)(IDataWriter** pwriter);
// get the name for the writer we want to use, default to BinaryWriter (which is in BinaryReader.dll)
// TODO: This seems like a find-replace operation?
@ -82,21 +66,20 @@ DataWriter<ElemType>::DataWriter(const ConfigRecordType& config)
writerType = L"KaldiReader";
}
ElemType elemType = ElemType();
GetWriterProc getWriterProc = (GetWriterProc) Plugin::Load(writerType, GetWriterName(elemType));
string precision = config(L"precision", "float");
GetWriterProc getWriterProc = (GetWriterProc)Plugin::Load(writerType, GetWriterName(precision));
m_dataWriter = NULL;
getWriterProc(&m_dataWriter);
m_dataWriter->Init(config);
}
template DataWriter<float>::DataWriter(const ConfigParameters&);
template DataWriter<double>::DataWriter(const ConfigParameters&);
template DataWriter<float>::DataWriter(const ScriptableObjects::IConfigRecord&);
template DataWriter<double>::DataWriter(const ScriptableObjects::IConfigRecord&);
template DataWriter::DataWriter(const ConfigParameters&);
template DataWriter::DataWriter(const ScriptableObjects::IConfigRecord&);
// destructor - cleanup temp files, etc.
template <class ElemType>
DataWriter<ElemType>::~DataWriter()
DataWriter::~DataWriter()
{
// free up resources
if (m_dataWriter)
@ -105,8 +88,7 @@ DataWriter<ElemType>::~DataWriter()
// GetSections - Get the sections of the file
// sections - a map of section name to section. Data sepcifications from config file will be used to determine where and how to save data
template <class ElemType>
void DataWriter<ElemType>::GetSections(std::map<std::wstring, SectionType, nocase_compare>& sections)
void DataWriter::GetSections(std::map<std::wstring, SectionType, nocase_compare>& sections)
{
m_dataWriter->GetSections(sections);
}
@ -117,8 +99,7 @@ void DataWriter<ElemType>::GetSections(std::map<std::wstring, SectionType, nocas
// numRecords - number of records we are saving, can be zero if not applicable
// datasetSize - Size of the dataset
// byteVariableSized - for variable sized data, size of current block to be written, zero when not used, or ignored if not variable sized data
template <class ElemType>
bool DataWriter<ElemType>::SaveData(size_t recordStart, const std::map<std::wstring, void*, nocase_compare>& matrices, size_t numRecords, size_t datasetSize, size_t byteVariableSized)
bool DataWriter::SaveData(size_t recordStart, const std::map<std::wstring, void*, nocase_compare>& matrices, size_t numRecords, size_t datasetSize, size_t byteVariableSized)
{
return m_dataWriter->SaveData(recordStart, matrices, numRecords, datasetSize, byteVariableSized);
}
@ -126,13 +107,9 @@ bool DataWriter<ElemType>::SaveData(size_t recordStart, const std::map<std::wstr
// SaveMapping - save a map into the file
// saveId - name of the section to save into (section:subsection format)
// labelMapping - map we are saving to the file
template <class ElemType>
void DataWriter<ElemType>::SaveMapping(std::wstring saveId, const std::map<LabelIdType, LabelType>& labelMapping)
void DataWriter::SaveMapping(std::wstring saveId, const std::map<LabelIdType, LabelType>& labelMapping)
{
m_dataWriter->SaveMapping(saveId, labelMapping);
}
//The explicit instantiation
template class DataWriter<double>;
template class DataWriter<float>;
} } }
}}}

Просмотреть файл

@ -85,14 +85,13 @@ public:
// Data Reader interface
// implemented by DataReader and underlying classes
// TODO: Remove <ElemType>. Only one method to go: SetNetOutput().
template <class ElemType>
class DATAREADER_API IDataReader
{
public:
typedef std::string LabelType; // surface form of an input token
typedef unsigned int LabelIdType; // input token mapped to an integer --TODO: why not size_t? Does this save space?
// BUGBUG: We should not have data members in an interace!
unsigned m_seed;
size_t mRequestedNumParallelSequences; // number of desired parallel sequences in each minibatch
@ -199,27 +198,20 @@ public:
return false;
}
};
typedef std::shared_ptr<IDataReader> IDataReaderPtr;
// GetReader - get a reader type from the DLL
// since we have 2 reader types based on template parameters, exposes 2 exports
// could be done directly the templated name, but that requires mangled C++ names
template <class ElemType>
void DATAREADER_API GetReader(IDataReader<ElemType>** preader);
extern "C" DATAREADER_API void GetReaderF(IDataReader<float>** preader);
extern "C" DATAREADER_API void GetReaderD(IDataReader<double>** preader);
// GetReaderX() - get a reader type from the DLL
// The F version gets the 'float' version, and D gets 'double'.
extern "C" DATAREADER_API void GetReaderF(IDataReader** preader);
extern "C" DATAREADER_API void GetReaderD(IDataReader** preader);
// Data Reader class
// interface for clients of the Data Reader
// mirrors the IDataReader interface, except the Init method is private (use the constructor)
template <class ElemType>
class DataReader : public IDataReader<ElemType>, protected Plugin, public ScriptableObjects::Object
class DataReader : public IDataReader, protected Plugin, public ScriptableObjects::Object
{
typedef typename IDataReader<ElemType>::LabelType LabelType;
typedef typename IDataReader<ElemType>::LabelIdType LabelIdType;
private:
vector<wstring> m_ioNames; // TODO: why are these needed, why not loop over m_dataReaders?
map<wstring, IDataReader<ElemType>*> m_dataReaders; // readers
map<wstring, IDataReader*> m_dataReaders; // readers
// Init - Reader Initialize for multiple data sets
// config - [in] configuration parameters for the datareader

Просмотреть файл

@ -46,7 +46,6 @@ enum SectionType
// Data Writer interface
// implemented by some DataWriters
template <class ElemType>
class DATAWRITER_API IDataWriter
{
public:
@ -64,26 +63,19 @@ public:
virtual void SaveMapping(std::wstring saveId, const std::map<LabelIdType, LabelType>& labelMapping) = 0;
virtual bool SupportMultiUtterances() const = 0;
};
typedef std::shared_ptr<IDataWriter> IDataWriterPtr;
// GetWriter - get a reader type from the DLL
// since we have 2 writerr types based on template parameters, exposes 2 exports
// could be done directly the templated name, but that requires mangled C++ names
template <class ElemType>
void DATAWRITER_API GetWriter(IDataWriter<ElemType>** pwriter);
extern "C" DATAWRITER_API void GetWriterF(IDataWriter<float>** pwriter);
extern "C" DATAWRITER_API void GetWriterD(IDataWriter<double>** pwriter);
// The F version gets the 'float' version, and D gets 'double'.
extern "C" DATAWRITER_API void GetWriterF(IDataWriter** pwriter);
extern "C" DATAWRITER_API void GetWriterD(IDataWriter** pwriter);
// Data Writer class
// interface for clients of the Data Writer
// mirrors the IDataWriter interface, except the Init method is private (use the constructor)
template <class ElemType>
class DataWriter : public IDataWriter<ElemType>, protected Plugin
class DataWriter : public IDataWriter, protected Plugin
{
typedef typename IDataWriter<ElemType>::LabelType LabelType;
typedef typename IDataWriter<ElemType>::LabelIdType LabelIdType;
private:
IDataWriter<ElemType>* m_dataWriter; // writer
IDataWriter* m_dataWriter; // writer
// Init - Writer Initialize for multiple data sets
// config - [in] configuration parameters for the datawriter

Просмотреть файл

@ -574,6 +574,7 @@ public:
// helper to access to element(0,0) without having to type-cast
virtual double Get00Element() const = 0;
virtual MatrixBasePtr ValuePtr() const = 0; // for use in readers that pass the agnostic object around
// TODO: two sets of functions, choose one
const std::wstring& NodeName() const { return m_nodeName; }
@ -1094,7 +1095,8 @@ public:
const Matrix<ElemType>& Value() const { return *m_value; }
Matrix<ElemType>& Value() { return *m_value; }
std::shared_ptr<Matrix<ElemType>> ValuePtr() { return m_value; } // readers want this as a shared_ptr straight
MatrixBasePtr ValuePtr() const override final { return m_value; } // readers want this as a shared_ptr straight
// Note: We cannot return a const& since returning m_value as a MatrixBasePtr is a type cast that generates a temporary. Interesting.
const Matrix<ElemType>& Gradient() const { return *m_gradient; }
Matrix<ElemType>& Gradient() { return *m_gradient; }
@ -1677,6 +1679,7 @@ public:
virtual void CopyTo(ComputationNodeBasePtr node, const std::wstring& newName, const CopyNodeFlags flags) const override { NOT_IMPLEMENTED; }
virtual ComputationNodeBasePtr Duplicate(const std::wstring& newName, const CopyNodeFlags flags) override { NOT_IMPLEMENTED; }
virtual double Get00Element() const override { NOT_IMPLEMENTED; }
virtual MatrixBasePtr ValuePtr() const override { NOT_IMPLEMENTED; }
virtual void UpdateFunctionMBSize() override { NOT_IMPLEMENTED; }
virtual void AttachInputs(const std::vector<ComputationNodeBasePtr>& inputs) override { NOT_IMPLEMENTED; }
virtual void PrintSelf(bool) const override { NOT_IMPLEMENTED; }

Просмотреть файл

@ -12,12 +12,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// Evaluation Reader class
// interface to pass to evaluation DLL
template <class ElemType>
class EvalReader : public IDataReader<ElemType>
class EvalReader : public IDataReader
{
typedef typename IDataReader<ElemType>::LabelType LabelType;
typedef typename IDataReader<ElemType>::LabelIdType LabelIdType;
private:
std::map<std::wstring, std::vector<ElemType>*>* m_inputs; // our input data
std::map<std::wstring, size_t>* m_dimensions; // the number of rows for the input data
size_t m_recordCount; // count of records in this data

Просмотреть файл

@ -12,16 +12,13 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// Evaluation Writer class
// interface to pass to evaluation DLL
template <class ElemType>
class EvalWriter : public IDataWriter<ElemType>
class EvalWriter : public IDataWriter
{
typedef typename IDataWriter<ElemType>::LabelType LabelType;
typedef typename IDataWriter<ElemType>::LabelIdType LabelIdType;
private:
std::map<std::wstring, std::vector<ElemType>*>* m_outputs; // our output data
std::map<std::wstring, size_t>* m_dimensions; // the number of rows for the output data
size_t m_recordCount; // count of records in this data
size_t m_currentRecord; // next record number to read
public:
// Method to setup the data for the reader
void SetData(std::map<std::wstring, std::vector<ElemType>*>* outputs, std::map<std::wstring, size_t>* dimensions)

Просмотреть файл

@ -541,12 +541,8 @@ public:
};
template <class ElemType>
class BinaryReader : public IDataReader<ElemType>
class BinaryReader : public IDataReader
{
typedef typename IDataReader<ElemType>::LabelType LabelType;
typedef typename IDataReader<ElemType>::LabelIdType LabelIdType;
private:
size_t m_mbSize; // size of minibatch requested
size_t m_mbStartSample; // starting sample # of the next minibatch
size_t m_epochSize; // size of an epoch
@ -615,12 +611,8 @@ public:
};
template <class ElemType>
class BinaryWriter : public IDataWriter<ElemType>
class BinaryWriter : public IDataWriter
{
typedef typename IDataWriter<ElemType>::LabelType LabelType;
typedef typename IDataWriter<ElemType>::LabelIdType LabelIdType;
private:
int m_traceLevel; // trace level to output the
size_t m_recordCurrent;
size_t m_recordMax;

Просмотреть файл

@ -14,33 +14,22 @@
namespace Microsoft { namespace MSR { namespace CNTK {
template <class ElemType>
void DATAREADER_API GetReader(IDataReader<ElemType>** preader)
extern "C" DATAREADER_API void GetReaderF(IDataReader** preader)
{
*preader = new BinaryReader<ElemType>();
*preader = new BinaryReader<float>();
}
extern "C" DATAREADER_API void GetReaderD(IDataReader** preader)
{
*preader = new BinaryReader<double>();
}
extern "C" DATAREADER_API void GetReaderF(IDataReader<float>** preader)
extern "C" DATAWRITER_API void GetWriterF(IDataWriter** pwriter)
{
GetReader(preader);
*pwriter = new BinaryWriter<float>();
}
extern "C" DATAREADER_API void GetReaderD(IDataReader<double>** preader)
extern "C" DATAWRITER_API void GetWriterD(IDataWriter** pwriter)
{
GetReader(preader);
*pwriter = new BinaryWriter<double>();
}
template <class ElemType>
void DATAWRITER_API GetWriter(IDataWriter<ElemType>** pwriter)
{
*pwriter = new BinaryWriter<ElemType>();
}
extern "C" DATAWRITER_API void GetWriterF(IDataWriter<float>** pwriter)
{
GetWriter(pwriter);
}
extern "C" DATAWRITER_API void GetWriterD(IDataWriter<double>** pwriter)
{
GetWriter(pwriter);
}
} } }
}}}

Просмотреть файл

@ -418,7 +418,7 @@ bool DSSMReader<ElemType>::GetMinibatch(StreamMinibatchInputs& matrices)
// GetLabelMapping - Gets the label mapping from integer index to label type
// returns - a map from numeric datatype to native label type
template <class ElemType>
const std::map<typename IDataReader<ElemType>::LabelIdType, typename IDataReader<ElemType>::LabelType>& DSSMReader<ElemType>::GetLabelMapping(const std::wstring& sectionName)
const std::map<IDataReader::LabelIdType, IDataReader::LabelType>& DSSMReader<ElemType>::GetLabelMapping(const std::wstring& sectionName)
{
if (m_cachingReader)
{
@ -431,7 +431,7 @@ const std::map<typename IDataReader<ElemType>::LabelIdType, typename IDataReader
// labelMapping - mapping table from label values to IDs (must be 0-n)
// note: for tasks with labels, the mapping table must be the same between a training run and a testing run
template <class ElemType>
void DSSMReader<ElemType>::SetLabelMapping(const std::wstring& /*sectionName*/, const std::map<typename IDataReader<ElemType>::LabelIdType, typename LabelType>& labelMapping)
void DSSMReader<ElemType>::SetLabelMapping(const std::wstring& /*sectionName*/, const std::map<LabelIdType, LabelType>& labelMapping)
{
if (m_cachingReader)
{

Просмотреть файл

@ -64,7 +64,7 @@ public:
};
template <class ElemType>
class DSSMReader : public IDataReader<ElemType>
class DSSMReader : public IDataReader
{
// public:
// typedef std::string LabelType;
@ -119,8 +119,8 @@ private:
std::map<LabelType, LabelIdType> m_mapLabelToId;
// caching support
DataReader<ElemType>* m_cachingReader;
DataWriter<ElemType>* m_cachingWriter;
DataReader* m_cachingReader;
DataWriter* m_cachingWriter;
ConfigParameters m_readerConfig;
size_t RandomizeSweep(size_t epochSample);
@ -172,7 +172,7 @@ public:
}
virtual const std::map<LabelIdType, LabelType>& GetLabelMapping(const std::wstring& sectionName);
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, typename LabelType>& labelMapping);
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, LabelType>& labelMapping);
virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart = 0);
virtual bool DataEnd();

Просмотреть файл

@ -12,18 +12,13 @@
namespace Microsoft { namespace MSR { namespace CNTK {
template <class ElemType>
void DATAREADER_API GetReader(IDataReader<ElemType>** preader)
extern "C" DATAREADER_API void GetReaderF(IDataReader** preader)
{
*preader = new DSSMReader<ElemType>();
*preader = new DSSMReader<float>();
}
extern "C" DATAREADER_API void GetReaderD(IDataReader** preader)
{
*preader = new DSSMReader<double>();
}
extern "C" DATAREADER_API void GetReaderF(IDataReader<float>** preader)
{
GetReader(preader);
}
extern "C" DATAREADER_API void GetReaderD(IDataReader<double>** preader)
{
GetReader(preader);
}
} } }
}}}

Просмотреть файл

@ -4,6 +4,8 @@
//
// DataWriter.cpp : Defines the exported functions for the DLL application.
//
// TODO: This is similar but not identical to Common/DataWriter.cpp. Why is this DataWriter different? Can it be reconciled?
//
#include "stdafx.h"
#include "Basics.h"
@ -16,59 +18,45 @@
namespace Microsoft { namespace MSR { namespace CNTK {
template <class ElemType>
void DATAWRITER_API GetWriter(IDataWriter<ElemType>** pwriter)
{
*pwriter = new HTKMLFWriter<ElemType>();
}
extern "C" DATAWRITER_API void GetWriterF(IDataWriter<float>** pwriter)
{
GetWriter(pwriter);
}
extern "C" DATAWRITER_API void GetWriterD(IDataWriter<double>** pwriter)
{
GetWriter(pwriter);
}
template <class ElemType>
template <class ConfigRecordType>
void DataWriter<ElemType>::InitFromConfig(const ConfigRecordType& writerConfig)
void DataWriter::InitFromConfig(const ConfigRecordType& writerConfig)
{
m_dataWriter = new HTKMLFWriter<ElemType>();
wstring precision = writerConfig(L"precision", L"float");
if (precision == L"float")
m_dataWriter = new HTKMLFWriter<float>();
else if (precision == L"double")
m_dataWriter = new HTKMLFWriter<double>();
else
InvalidArgument("DataWriter (HTKMLFWriter): The 'precision' parameter must be 'float' or 'double'.");
m_dataWriter->Init(writerConfig);
}
// Destroy - cleanup and remove this class
// NOTE: this destroys the object, and it can't be used past this point
template <class ElemType>
void DataWriter<ElemType>::Destroy()
void DataWriter::Destroy()
{
delete m_dataWriter;
m_dataWriter = NULL;
// TODO: do we need to destroy ourselves as well?
}
// DataWriter Constructor
// config - [in] configuration data for the data writer
template <class ElemType>
template <class ConfigRecordType>
DataWriter<ElemType>::DataWriter(const ConfigRecordType& config)
DataWriter::DataWriter(const ConfigRecordType& config)
{
Init(config);
}
// destructor - cleanup temp files, etc.
template <class ElemType>
DataWriter<ElemType>::~DataWriter()
DataWriter::~DataWriter()
{
delete m_dataWriter;
m_dataWriter = NULL;
Destroy();
}
// GetSections - Get the sections of the file
// sections - a map of section name to section. Data sepcifications from config file will be used to determine where and how to save data
template <class ElemType>
void DataWriter<ElemType>::GetSections(std::map<std::wstring, SectionType, nocase_compare>& sections)
void DataWriter::GetSections(std::map<std::wstring, SectionType, nocase_compare>& sections)
{
m_dataWriter->GetSections(sections);
}
@ -79,8 +67,7 @@ void DataWriter<ElemType>::GetSections(std::map<std::wstring, SectionType, nocas
// numRecords - number of records we are saving, can be zero if not applicable
// datasetSize - Size of the dataset
// byteVariableSized - for variable sized data, size of current block to be written, zero when not used, or ignored if not variable sized data
template <class ElemType>
bool DataWriter<ElemType>::SaveData(size_t recordStart, const std::map<std::wstring, void*, nocase_compare>& matrices, size_t numRecords, size_t datasetSize, size_t byteVariableSized)
bool DataWriter::SaveData(size_t recordStart, const std::map<std::wstring, void*, nocase_compare>& matrices, size_t numRecords, size_t datasetSize, size_t byteVariableSized)
{
return m_dataWriter->SaveData(recordStart, matrices, numRecords, datasetSize, byteVariableSized);
}
@ -88,13 +75,9 @@ bool DataWriter<ElemType>::SaveData(size_t recordStart, const std::map<std::wstr
// SaveMapping - save a map into the file
// saveId - name of the section to save into (section:subsection format)
// labelMapping - map we are saving to the file
template <class ElemType>
void DataWriter<ElemType>::SaveMapping(std::wstring saveId, const std::map<LabelIdType, LabelType>& labelMapping)
void DataWriter::SaveMapping(std::wstring saveId, const std::map<LabelIdType, LabelType>& labelMapping)
{
m_dataWriter->SaveMapping(saveId, labelMapping);
}
//The explicit instantiation
template class DataWriter<double>;
template class DataWriter<float>;
} } }
}}}

Просмотреть файл

@ -8,36 +8,31 @@
#include "stdafx.h"
#include "Basics.h"
#include "htkfeatio.h" // for reading HTK features
#ifdef _WIN32
#include "latticearchive.h" // for reading HTK phoneme lattices (MMI training)
#endif
#include "simplesenonehmm.h" // for MMI scoring
#include "msra_mgram.h" // for unigram scores of ground-truth path in sequence training
#include "rollingwindowsource.h" // minibatch sources
#include "chunkevalsource.h"
#define DATAREADER_EXPORTS
#include "DataReader.h"
#define DATAWRITER_EXPORTS
#include "HTKMLFReader.h"
#include "Config.h"
#include "HTKMLFWriter.h"
namespace Microsoft { namespace MSR { namespace CNTK {
template <class ElemType>
void DATAREADER_API GetReader(IDataReader<ElemType>** preader)
extern "C" DATAREADER_API void GetReaderF(IDataReader** preader)
{
*preader = new HTKMLFReader<ElemType>();
*preader = new HTKMLFReader<float>();
}
extern "C" DATAREADER_API void GetReaderD(IDataReader** preader)
{
*preader = new HTKMLFReader<double>();
}
extern "C" DATAREADER_API void GetReaderF(IDataReader<float>** preader)
extern "C" DATAWRITER_API void GetWriterF(IDataWriter** pwriter)
{
GetReader(preader);
*pwriter = new HTKMLFWriter<float>();
}
extern "C" DATAREADER_API void GetReaderD(IDataReader<double>** preader)
extern "C" DATAWRITER_API void GetWriterD(IDataWriter** pwriter)
{
GetReader(preader);
*pwriter = new HTKMLFWriter<double>();
}
#ifdef _WIN32
// Utility function, in ConfigFile.cpp, but HTKMLFReader doesn't need that code...
@ -58,4 +53,5 @@ void Trim(std::string& str)
str.erase(found + 1);
}
#endif
} } }
}}}

Просмотреть файл

@ -1732,7 +1732,7 @@ bool HTKMLFReader<ElemType>::ReNewBufferForMultiIO(size_t i)
// GetLabelMapping - Gets the label mapping from integer to type in file
// mappingTable - a map from numeric datatype to native label type stored as a string
template <class ElemType>
const std::map<typename IDataReader<ElemType>::LabelIdType, typename IDataReader<ElemType>::LabelType>& HTKMLFReader<ElemType>::GetLabelMapping(const std::wstring& /*sectionName*/)
const std::map<IDataReader::LabelIdType, IDataReader::LabelType>& HTKMLFReader<ElemType>::GetLabelMapping(const std::wstring& /*sectionName*/)
{
return m_idToLabelMap;
}
@ -1741,7 +1741,7 @@ const std::map<typename IDataReader<ElemType>::LabelIdType, typename IDataReader
// labelMapping - mapping table from label values to IDs (must be 0-n)
// note: for tasks with labels, the mapping table must be the same between a training run and a testing run
template <class ElemType>
void HTKMLFReader<ElemType>::SetLabelMapping(const std::wstring& /*sectionName*/, const std::map<typename IDataReader<ElemType>::LabelIdType, typename IDataReader<ElemType>::LabelType>& labelMapping)
void HTKMLFReader<ElemType>::SetLabelMapping(const std::wstring& /*sectionName*/, const std::map<IDataReader::LabelIdType, IDataReader::LabelType>& labelMapping)
{
m_idToLabelMap = labelMapping;
}

Просмотреть файл

@ -9,10 +9,19 @@
#include "Config.h" // for intargvector
#include "CUDAPageLockedMemAllocator.h"
#include "htkfeatio.h" // for reading HTK features
#ifdef _WIN32
#include "latticearchive.h" // for reading HTK phoneme lattices (MMI training)
#endif
#include "simplesenonehmm.h" // for MMI scoring
#include "msra_mgram.h" // for unigram scores of ground-truth path in sequence training
#include "rollingwindowsource.h" // minibatch sources
#include "chunkevalsource.h"
namespace Microsoft { namespace MSR { namespace CNTK {
template <class ElemType>
class HTKMLFReader : public IDataReader<ElemType>
class HTKMLFReader : public IDataReader
{
private:
const static size_t m_htkRandomizeAuto = 0;
@ -41,8 +50,8 @@ private:
size_t m_extraNumSeqs;
bool m_noData;
bool m_trainOrTest; // if false, in file writing mode
using LabelType = typename IDataReader<ElemType>::LabelType;
using LabelIdType = typename IDataReader<ElemType>::LabelIdType;
using IDataReader::LabelType;
using IDataReader::LabelIdType;
std::map<LabelIdType, LabelType> m_idToLabelMap;

Просмотреть файл

@ -118,8 +118,8 @@
<ClCompile Include="..\..\Common\TimerUtility.cpp">
<PrecompiledHeader>NotUsing</PrecompiledHeader>
</ClCompile>
<ClCompile Include="DataReader.cpp" />
<ClCompile Include="DataWriter.cpp" />
<ClCompile Include="Exports.cpp" />
<ClCompile Include="DataWriterLocal.cpp" />
<ClCompile Include="dllmain.cpp">
<CompileAsManaged Condition="$(DebugBuild)">false</CompileAsManaged>
<PrecompiledHeader Condition="$(DebugBuild)">

Просмотреть файл

@ -1,13 +1,11 @@
<?xml version="1.0" encoding="utf-8"?>
<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<ItemGroup>
<ClCompile Include="DataWriter.cpp" />
<ClCompile Include="dllmain.cpp" />
<ClCompile Include="HTKMLFReader.cpp" />
<ClCompile Include="HTKMLFWriter.cpp" />
<ClCompile Include="latticearchive.cpp" />
<ClCompile Include="stdafx.cpp" />
<ClCompile Include="DataReader.cpp" />
<ClCompile Include="..\..\Common\TimerUtility.cpp">
<Filter>Common</Filter>
</ClCompile>
@ -15,6 +13,8 @@
<Filter>Common</Filter>
</ClCompile>
<ClCompile Include="..\..\Common\ExceptionWithCallStack.cpp" />
<ClCompile Include="Exports.cpp" />
<ClCompile Include="DataWriterLocal.cpp" />
</ItemGroup>
<ItemGroup>
<ClInclude Include="biggrowablevectors.h" />

Просмотреть файл

@ -13,7 +13,7 @@
namespace Microsoft { namespace MSR { namespace CNTK {
template <class ElemType>
class HTKMLFWriter : public IDataWriter<ElemType>
class HTKMLFWriter : public IDataWriter
{
private:
std::vector<size_t> outputDims;
@ -36,8 +36,6 @@ private:
};
public:
using LabelType = typename IDataWriter<ElemType>::LabelType;
using LabelIdType = typename IDataWriter<ElemType>::LabelIdType;
template <class ConfigRecordType>
void InitFromConfig(const ConfigRecordType& writerConfig);
virtual void Init(const ConfigParameters& config)

Просмотреть файл

@ -22,13 +22,14 @@ auto factory = [](const ConfigParameters& parameters) -> ReaderPtr
return std::make_shared<ImageReader>(std::make_shared<HeapMemoryProvider>(), parameters);
};
extern "C" DATAREADER_API void GetReaderF(IDataReader<float>** preader)
extern "C" DATAREADER_API void GetReaderF(IDataReader** preader)
{
*preader = new ReaderShim<float>(factory);
}
extern "C" DATAREADER_API void GetReaderD(IDataReader<double>** preader)
extern "C" DATAREADER_API void GetReaderD(IDataReader** preader)
{
*preader = new ReaderShim<double>(factory);
}
} } }
}}}

Просмотреть файл

@ -4,6 +4,8 @@
//
// DataReader.cpp : Defines the exported functions for the DLL application.
//
// TODO: Rename to Exports.cpp
//
#include "stdafx.h"
#include "basetypes.h"
@ -23,19 +25,13 @@
namespace Microsoft { namespace MSR { namespace CNTK {
template <class ElemType>
void DATAREADER_API GetReader(IDataReader<ElemType>** preader)
extern "C" DATAREADER_API void GetReaderF(IDataReader** preader)
{
*preader = new HTKMLFReader<ElemType>();
*preader = new HTKMLFReader<float>();
}
extern "C" DATAREADER_API void GetReaderF(IDataReader<float>** preader)
extern "C" DATAREADER_API void GetReaderD(IDataReader** preader)
{
GetReader(preader);
}
extern "C" DATAREADER_API void GetReaderD(IDataReader<double>** preader)
{
GetReader(preader);
*preader = new HTKMLFReader<double>();
}
// Utility function, in ConfigFile.cpp, but HTKMLFReader doesn't need that code...

Просмотреть файл

@ -4,6 +4,8 @@
//
// DataWriter.cpp : Defines the exported functions for the DLL application.
//
// TODO: This is similar but not identical to Common/DataWriter.cpp. Why is this DataWriter different? Can it be reconciled?
//
#include "stdafx.h"
#include "basetypes.h"
@ -16,59 +18,54 @@
namespace Microsoft { namespace MSR { namespace CNTK {
template <class ElemType>
void DATAWRITER_API GetWriter(IDataWriter<ElemType>** pwriter)
// TODO: move these to Exports.cpp
extern "C" DATAWRITER_API void GetWriterF(IDataWriter** pwriter)
{
*pwriter = new HTKMLFWriter<ElemType>();
*pwriter = new HTKMLFWriter<float>();
}
extern "C" DATAWRITER_API void GetWriterD(IDataWriter** pwriter)
{
*pwriter = new HTKMLFWriter<double>();
}
extern "C" DATAWRITER_API void GetWriterF(IDataWriter<float>** pwriter)
{
GetWriter(pwriter);
}
extern "C" DATAWRITER_API void GetWriterD(IDataWriter<double>** pwriter)
{
GetWriter(pwriter);
}
template <class ElemType>
template <class ConfigRecordType>
void DataWriter<ElemType>::InitFromConfig(const ConfigRecordType& writerConfig)
void DataWriter::InitFromConfig(const ConfigRecordType& writerConfig)
{
m_dataWriter = new HTKMLFWriter<ElemType>();
wstring precision = writerConfig(L"precision", L"float");
if (precision == L"float")
m_dataWriter = new HTKMLFWriter<float>();
else if (precision == L"double")
m_dataWriter = new HTKMLFWriter<double>();
else
InvalidArgument("DataWriter (Kaldi HTKMLFWriter): The 'precision' parameter must be 'float' or 'double'.");
m_dataWriter->Init(writerConfig);
}
// Destroy - cleanup and remove this class
// NOTE: this destroys the object, and it can't be used past this point
template <class ElemType>
void DataWriter<ElemType>::Destroy()
void DataWriter::Destroy()
{
delete m_dataWriter;
m_dataWriter = NULL;
}
// DataWriter Constructor
// config - [in] configuration data for the data writer
template <class ElemType>
template <class ConfigRecordType>
DataWriter<ElemType>::DataWriter(const ConfigRecordType& config)
DataWriter::DataWriter(const ConfigRecordType& config)
{
Init(config);
}
// destructor - cleanup temp files, etc.
template <class ElemType>
DataWriter<ElemType>::~DataWriter()
DataWriter::~DataWriter()
{
delete m_dataWriter;
m_dataWriter = NULL;
Destroy();
}
// GetSections - Get the sections of the file
// sections - a map of section name to section. Data sepcifications from config file will be used to determine where and how to save data
template <class ElemType>
void DataWriter<ElemType>::GetSections(std::map<std::wstring, SectionType, nocase_compare>& sections)
void DataWriter::GetSections(std::map<std::wstring, SectionType, nocase_compare>& sections)
{
m_dataWriter->GetSections(sections);
}
@ -79,8 +76,7 @@ void DataWriter<ElemType>::GetSections(std::map<std::wstring, SectionType, nocas
// numRecords - number of records we are saving, can be zero if not applicable
// datasetSize - Size of the dataset
// byteVariableSized - for variable sized data, size of current block to be written, zero when not used, or ignored if not variable sized data
template <class ElemType>
bool DataWriter<ElemType>::SaveData(size_t recordStart, const std::map<std::wstring, void*, nocase_compare>& matrices, size_t numRecords, size_t datasetSize, size_t byteVariableSized)
bool DataWriter::SaveData(size_t recordStart, const std::map<std::wstring, void*, nocase_compare>& matrices, size_t numRecords, size_t datasetSize, size_t byteVariableSized)
{
return m_dataWriter->SaveData(recordStart, matrices, numRecords, datasetSize, byteVariableSized);
}
@ -88,13 +84,9 @@ bool DataWriter<ElemType>::SaveData(size_t recordStart, const std::map<std::wstr
// SaveMapping - save a map into the file
// saveId - name of the section to save into (section:subsection format)
// labelMapping - map we are saving to the file
template <class ElemType>
void DataWriter<ElemType>::SaveMapping(std::wstring saveId, const std::map<LabelIdType, LabelType>& labelMapping)
void DataWriter::SaveMapping(std::wstring saveId, const std::map<LabelIdType, LabelType>& labelMapping)
{
m_dataWriter->SaveMapping(saveId, labelMapping);
}
//The explicit instantiation
template class DataWriter<double>;
template class DataWriter<float>;
} } }
}}}

Просмотреть файл

@ -1916,7 +1916,7 @@ bool HTKMLFReader<ElemType>::SetNetOutput(
// GetLabelMapping - Gets the label mapping from integer to type in file
// mappingTable - a map from numeric datatype to native label type stored as a string
template <class ElemType>
const std::map<typename IDataReader<ElemType>::LabelIdType, typename IDataReader<ElemType>::LabelType>& HTKMLFReader<ElemType>::GetLabelMapping(const std::wstring& /*sectionName*/)
const std::map<IDataReader::LabelIdType, IDataReader::LabelType>& HTKMLFReader<ElemType>::GetLabelMapping(const std::wstring& /*sectionName*/)
{
return m_idToLabelMap;
}
@ -1925,7 +1925,7 @@ const std::map<typename IDataReader<ElemType>::LabelIdType, typename IDataReader
// labelMapping - mapping table from label values to IDs (must be 0-n)
// note: for tasks with labels, the mapping table must be the same between a training run and a testing run
template <class ElemType>
void HTKMLFReader<ElemType>::SetLabelMapping(const std::wstring& /*sectionName*/, const std::map<typename IDataReader<ElemType>::LabelIdType, LabelType>& labelMapping)
void HTKMLFReader<ElemType>::SetLabelMapping(const std::wstring& /*sectionName*/, const std::map<LabelIdType, LabelType>& labelMapping)
{
m_idToLabelMap = labelMapping;
}

Просмотреть файл

@ -13,7 +13,7 @@
namespace Microsoft { namespace MSR { namespace CNTK {
template <class ElemType>
class HTKMLFReader : public IDataReader<ElemType>
class HTKMLFReader : public IDataReader
{
private:
msra::dbn::minibatchiterator* m_mbiter;
@ -71,8 +71,6 @@ private:
bool m_noData;
bool m_trainOrTest; // if false, in file writing mode
using LabelType = typename IDataReader<ElemType>::LabelType;
using LabelIdType = typename IDataReader<ElemType>::LabelIdType;
std::map<LabelIdType, LabelType> m_idToLabelMap;

Просмотреть файл

@ -12,7 +12,7 @@
namespace Microsoft { namespace MSR { namespace CNTK {
template <class ElemType>
class HTKMLFWriter : public IDataWriter<ElemType>
class HTKMLFWriter : public IDataWriter
{
private:
std::vector<size_t> outputDims;

Просмотреть файл

@ -7,23 +7,28 @@
#include "stdafx.h"
#define DATAREADER_EXPORTS
#include "DataReader.h"
#define DATAWRITER_EXPORTS
#include "SequenceReader.h"
#include "SequenceWriter.h"
namespace Microsoft { namespace MSR { namespace CNTK {
template <class ElemType>
void DATAREADER_API GetReader(IDataReader<ElemType>** preader)
extern "C" DATAREADER_API void GetReaderF(IDataReader** preader)
{
*preader = new BatchSequenceReader<ElemType>();
*preader = new BatchSequenceReader<float>();
}
extern "C" DATAREADER_API void GetReaderD(IDataReader** preader)
{
*preader = new BatchSequenceReader<double>();
}
extern "C" DATAREADER_API void GetReaderF(IDataReader<float>** preader)
extern "C" DATAWRITER_API void GetWriterF(IDataWriter** pwriter)
{
GetReader(preader);
*pwriter = new LMSequenceWriter<float>();
}
extern "C" DATAREADER_API void GetReaderD(IDataReader<double>** preader)
extern "C" DATAWRITER_API void GetWriterD(IDataWriter** pwriter)
{
GetReader(preader);
*pwriter = new LMSequenceWriter<double>();
}
} } }
}}}

Просмотреть файл

@ -12,6 +12,7 @@
#ifdef LEAKDETECT
#include <vld.h> // leak detection
#endif
#include "DataWriter.h"
#include "fileutil.h" // for fexists()
#include <iostream>
#include <vector>
@ -68,7 +69,7 @@ size_t SequenceReader<ElemType>::RecordsToRead(size_t mbStartSample, bool tail)
// endOfDataCheck - check if we are at the end of the dataset (no wraparound)
// returns - true if we have more to read, false if we hit the end of the dataset
template <class ElemType>
typename IDataReader<ElemType>::LabelIdType SequenceReader<ElemType>::GetIdFromLabel(const std::string& labelValue, LabelInfo& labelInfo)
IDataReader::LabelIdType SequenceReader<ElemType>::GetIdFromLabel(const std::string& labelValue, LabelInfo& labelInfo)
{
auto found = labelInfo.mapLabelToId.find(labelValue);
// not yet found, add to the map
@ -701,11 +702,11 @@ void SequenceReader<ElemType>::InitCache(const ConfigParameters& readerConfig)
// mmodify the config so the reader types look correct
config["readerType"] = config("writerType");
config["file"] = filesList;
m_cachingReader = new DataReader<ElemType>(config);
m_cachingReader = new DataReader(config);
}
else
{
m_cachingWriter = new DataWriter<ElemType>(readerConfig);
m_cachingWriter = new DataWriter(readerConfig);
// now get the section names for map and category types
std::map<std::wstring, SectionType, nocase_compare> sections;
@ -877,8 +878,8 @@ void SequenceReader<ElemType>::StartMinibatchLoop(size_t mbSize, size_t epoch, s
{
m_labelsBuffer = new ElemType[mbSize * labelInfo.dim];
memset(m_labelsBuffer, 0, sizeof(ElemType) * mbSize * labelInfo.dim);
m_labelsIdBuffer = new typename IDataReader<ElemType>::LabelIdType[mbSize];
memset(m_labelsIdBuffer, 0, sizeof(typename IDataReader<ElemType>::LabelIdType) * mbSize);
m_labelsIdBuffer = new LabelIdType[mbSize];
memset(m_labelsIdBuffer, 0, sizeof(LabelIdType) * mbSize);
}
else if (labelInfo.type != labelNone)
{
@ -1188,7 +1189,7 @@ bool SequenceReader<ElemType>::GetMinibatch(StreamMinibatchInputs& matrices)
if (labelInfo.type == labelCategory)
{
memset(m_labelsBuffer, 0, sizeof(ElemType) * labelInfo.dim * actualmbsize);
memset(m_labelsIdBuffer, 0, sizeof(typename IDataReader<ElemType>::LabelIdType) * actualmbsize);
memset(m_labelsIdBuffer, 0, sizeof(LabelIdType) * actualmbsize);
}
else if (labelInfo.type != labelNone)
{
@ -1307,7 +1308,7 @@ bool SequenceReader<ElemType>::GetMinibatch(StreamMinibatchInputs& matrices)
// GetLabelMapping - Gets the label mapping from integer index to label type
// returns - a map from numeric datatype to native label type
template <class ElemType>
const std::map<typename IDataReader<ElemType>::LabelIdType, typename IDataReader<ElemType>::LabelType>& SequenceReader<ElemType>::GetLabelMapping(const std::wstring& sectionName)
const std::map<IDataReader::LabelIdType, IDataReader::LabelType>& SequenceReader<ElemType>::GetLabelMapping(const std::wstring& sectionName)
{
FailBecauseDeprecated(__FUNCTION__); // DEPRECATED CLASS, SHOULD NOT BE USED ANYMORE
@ -1324,7 +1325,7 @@ const std::map<typename IDataReader<ElemType>::LabelIdType, typename IDataReader
// labelMapping - mapping table from label values to IDs (must be 0-n)
// note: for tasks with labels, the mapping table must be the same between a training run and a testing run
template <class ElemType>
void SequenceReader<ElemType>::SetLabelMapping(const std::wstring& /*sectionName*/, const std::map<typename IDataReader<ElemType>::LabelIdType, LabelType>& labelMapping)
void SequenceReader<ElemType>::SetLabelMapping(const std::wstring& /*sectionName*/, const std::map<IDataReader::LabelIdType, LabelType>& labelMapping)
{
FailBecauseDeprecated(__FUNCTION__); // DEPRECATED CLASS, SHOULD NOT BE USED ANYMORE

Просмотреть файл

@ -109,9 +109,8 @@ public:
// Note: This class is deprecated for standalone use, only used as a base for BatchSequenceReader which overrides most of the functions.
template <class ElemType>
class SequenceReader : public IDataReader<ElemType>
class SequenceReader : public IDataReader
{
typedef IDataReader<ElemType> Base;
protected:
bool m_idx2clsRead;
bool m_clsinfoRead;
@ -119,9 +118,6 @@ protected:
bool m_idx2probRead;
public:
using LabelType = typename Base::LabelType;
using LabelIdType = typename Base::LabelIdType;
map<string, int> word4idx;
map<int, string> idx4word;
map<int, int> idx4class;
@ -211,8 +207,8 @@ protected:
} m_labelInfo[labelInfoNum];
// caching support
DataReader<ElemType>* m_cachingReader;
DataWriter<ElemType>* m_cachingWriter;
DataReader* m_cachingReader;
DataWriter* m_cachingWriter;
ConfigParameters m_readerConfig;
void InitCache(const ConfigParameters& config);

Просмотреть файл

@ -12,7 +12,7 @@
namespace Microsoft { namespace MSR { namespace CNTK {
template <class ElemType>
class LMSequenceWriter : public IDataWriter<ElemType>
class LMSequenceWriter : public IDataWriter
{
private:
std::vector<size_t> outputDims;
@ -49,8 +49,6 @@ public:
}
public:
using LabelType = typename IDataWriter<ElemType>::LabelType;
using LabelIdType = typename IDataWriter<ElemType>::LabelIdType;
void GetSections(std::map<std::wstring, SectionType, nocase_compare>& /*sections*/)
{
}
@ -77,20 +75,4 @@ public:
};
};
template <class ElemType>
void DATAWRITER_API GetWriter(IDataWriter<ElemType>** pwriter)
{
assert(pwriter != nullptr);
*pwriter = new LMSequenceWriter<ElemType>();
assert(*pwriter != nullptr);
}
extern "C" DATAWRITER_API void GetWriterF(IDataWriter<float>** pwriter)
{
GetWriter(pwriter);
}
extern "C" DATAWRITER_API void GetWriterD(IDataWriter<double>** pwriter)
{
GetWriter(pwriter);
}
} } }
}}}

Просмотреть файл

@ -4,6 +4,8 @@
//
// DataWriter.cpp : Defines the exported functions for the DLL application.
//
// TODO: Unify with shared DataWriter.cpp.
//
#include "stdafx.h"
#include "Basics.h"
@ -14,59 +16,45 @@
namespace Microsoft { namespace MSR { namespace CNTK {
template <class ElemType>
void DATAWRITER_API GetWriter(IDataWriter<ElemType>** pwriter)
{
*pwriter = new LUSequenceWriter<ElemType>();
}
extern "C" DATAWRITER_API void GetWriterF(IDataWriter<float>** pwriter)
{
GetWriter(pwriter);
}
extern "C" DATAWRITER_API void GetWriterD(IDataWriter<double>** pwriter)
{
GetWriter(pwriter);
}
template <class ElemType>
template <class ConfigRecordType>
void DataWriter<ElemType>::InitFromConfig(const ConfigRecordType& writerConfig)
void DataWriter::InitFromConfig(const ConfigRecordType& writerConfig)
{
m_dataWriter = new LUSequenceWriter<ElemType>();
wstring precision = writerConfig(L"precision", L"float");
if (precision == L"float")
m_dataWriter = new LUSequenceWriter<float>();
else if (precision == L"double")
m_dataWriter = new LUSequenceWriter<double>();
else
InvalidArgument("DataWriter (LUSequenceWriter): The 'precision' parameter must be 'float' or 'double'.");
m_dataWriter->Init(writerConfig);
}
// Destroy - cleanup and remove this class
// NOTE: this destroys the object, and it can't be used past this point
template <class ElemType>
void DataWriter<ElemType>::Destroy()
void DataWriter::Destroy()
{
delete m_dataWriter;
m_dataWriter = NULL;
// TODO: don't we need to destroy ourselves?
}
// DataWriter Constructor
// config - [in] configuration data for the data writer
template <class ElemType>
template <class ConfigRecordType>
DataWriter<ElemType>::DataWriter(const ConfigRecordType& config)
DataWriter::DataWriter(const ConfigRecordType& config)
{
Init(config);
}
// destructor - cleanup temp files, etc.
template <class ElemType>
DataWriter<ElemType>::~DataWriter()
DataWriter::~DataWriter()
{
delete m_dataWriter;
m_dataWriter = NULL;
Destroy();
}
// GetSections - Get the sections of the file
// sections - a map of section name to section. Data sepcifications from config file will be used to determine where and how to save data
template <class ElemType>
void DataWriter<ElemType>::GetSections(std::map<std::wstring, SectionType, nocase_compare>& sections)
void DataWriter::GetSections(std::map<std::wstring, SectionType, nocase_compare>& sections)
{
m_dataWriter->GetSections(sections);
}
@ -77,8 +65,7 @@ void DataWriter<ElemType>::GetSections(std::map<std::wstring, SectionType, nocas
// numRecords - number of records we are saving, can be zero if not applicable
// datasetSize - Size of the dataset
// byteVariableSized - for variable sized data, size of current block to be written, zero when not used, or ignored if not variable sized data
template <class ElemType>
bool DataWriter<ElemType>::SaveData(size_t recordStart, const std::map<std::wstring, void*, nocase_compare>& matrices, size_t numRecords, size_t datasetSize, size_t byteVariableSized)
bool DataWriter::SaveData(size_t recordStart, const std::map<std::wstring, void*, nocase_compare>& matrices, size_t numRecords, size_t datasetSize, size_t byteVariableSized)
{
return m_dataWriter->SaveData(recordStart, matrices, numRecords, datasetSize, byteVariableSized);
}
@ -86,13 +73,9 @@ bool DataWriter<ElemType>::SaveData(size_t recordStart, const std::map<std::wstr
// SaveMapping - save a map into the file
// saveId - name of the section to save into (section:subsection format)
// labelMapping - map we are saving to the file
template <class ElemType>
void DataWriter<ElemType>::SaveMapping(std::wstring saveId, const std::map<typename LabelIdType, typename LabelType>& labelMapping)
void DataWriter::SaveMapping(std::wstring saveId, const std::map<LabelIdType, LabelType>& labelMapping)
{
m_dataWriter->SaveMapping(saveId, labelMapping);
}
//The explicit instantiation
template class DataWriter<double>;
template class DataWriter<float>;
} } }
}}}

Просмотреть файл

@ -7,25 +7,28 @@
#include "stdafx.h"
#define DATAREADER_EXPORTS
#include "DataReader.h"
#define DATAWRITER_EXPORTS
#include "LUSequenceReader.h"
#include "LUSequenceWriter.h"
namespace Microsoft { namespace MSR { namespace CNTK {
template <class ElemType>
void DATAREADER_API GetReader(IDataReader<ElemType>** preader)
extern "C" DATAREADER_API void GetReaderF(IDataReader** preader)
{
// *preader = new LUSequenceReader<ElemType>();
// *preader = new BatchLUSequenceReader<ElemType>();
*preader = new MultiIOBatchLUSequenceReader<ElemType>();
*preader = new MultiIOBatchLUSequenceReader<float>();
}
extern "C" DATAREADER_API void GetReaderD(IDataReader** preader)
{
*preader = new MultiIOBatchLUSequenceReader<double>();
}
extern "C" DATAREADER_API void GetReaderF(IDataReader<float>** preader)
extern "C" DATAWRITER_API void GetWriterF(IDataWriter** pwriter)
{
GetReader(preader);
*pwriter = new LUSequenceWriter<float>();
}
extern "C" DATAREADER_API void GetReaderD(IDataReader<double>** preader)
extern "C" DATAWRITER_API void GetWriterD(IDataWriter** pwriter)
{
GetReader(preader);
*pwriter = new LUSequenceWriter<double>();
}
} } }
}}}

Просмотреть файл

@ -5,6 +5,8 @@
// LUSequenceParser.h : Parses the UCI format using a custom state machine (for speed)
//
#pragma once
#include <string>
#include <vector>
#include <assert.h>

Просмотреть файл

@ -47,7 +47,7 @@ enum ReaderMode
};
template <class ElemType>
class LUSequenceReader : public IDataReader<ElemType>
class LUSequenceReader : public IDataReader
{
protected:
bool m_idx2clsRead;
@ -149,8 +149,8 @@ protected:
} m_labelInfo[labelInfoNum];
// caching support
DataReader<ElemType>* m_cachingReader;
DataWriter<ElemType>* m_cachingWriter;
DataReader* m_cachingReader;
DataWriter* m_cachingWriter;
ConfigParameters m_readerConfig;
void InitCache(const ConfigParameters& config);

Просмотреть файл

@ -120,7 +120,7 @@
<ClCompile Include="..\..\Common\Config.cpp">
<PrecompiledHeader>NotUsing</PrecompiledHeader>
</ClCompile>
<ClCompile Include="DataWriter.cpp" />
<ClCompile Include="DataWriterLocal.cpp" />
<ClCompile Include="Exports.cpp" />
<ClCompile Include="dllmain.cpp">
<CompileAsManaged Condition="$(DebugBuild)">false</CompileAsManaged>

Просмотреть файл

@ -10,9 +10,6 @@
<ClCompile Include="..\..\Common\DataReader.cpp">
<Filter>Common</Filter>
</ClCompile>
<ClCompile Include="DataWriter.cpp">
<Filter>Common</Filter>
</ClCompile>
<ClCompile Include="..\..\Common\fileutil.cpp">
<Filter>Common</Filter>
</ClCompile>
@ -21,6 +18,9 @@
</ClCompile>
<ClCompile Include="..\..\Common\Config.cpp" />
<ClCompile Include="..\..\Common\ExceptionWithCallStack.cpp" />
<ClCompile Include="DataWriterLocal.cpp">
<Filter>Common</Filter>
</ClCompile>
</ItemGroup>
<ItemGroup>
<ClInclude Include="LUSequenceWriter.h" />
@ -51,8 +51,5 @@
<Filter Include="Common\Include">
<UniqueIdentifier>{85d2fa50-2b95-4ec7-8f2c-c5c0b1cb493e}</UniqueIdentifier>
</Filter>
<Filter Include="Duplicates to remove">
<UniqueIdentifier>{10f819fb-8861-4607-9389-60ca80f968c2}</UniqueIdentifier>
</Filter>
</ItemGroup>
</Project>

Просмотреть файл

@ -2,17 +2,21 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#pragma once
#include "DataWriter.h"
#include "LUSequenceParser.h"
#include <stdio.h>
#ifndef MAX_STRING
#define MAX_STRING 2048
#endif
namespace Microsoft { namespace MSR { namespace CNTK {
template <class ElemType>
class LUSequenceWriter : public IDataWriter<ElemType>
class LUSequenceWriter : public IDataWriter
{
private:
std::vector<size_t> outputDims;
@ -42,7 +46,7 @@ public:
void GetSections(std::map<std::wstring, SectionType, nocase_compare>& /*sections*/)
{
}
void SaveMapping(std::wstring saveId, const std::map<typename LabelIdType, typename LabelType>& /*labelMapping*/)
void SaveMapping(std::wstring saveId, const std::map<LabelIdType, LabelType>& /*labelMapping*/)
{
}

Просмотреть файл

@ -12,18 +12,13 @@
namespace Microsoft { namespace MSR { namespace CNTK {
template <class ElemType>
void DATAREADER_API GetReader(IDataReader<ElemType>** preader)
extern "C" DATAREADER_API void GetReaderF(IDataReader** preader)
{
*preader = new LibSVMBinaryReader<ElemType>();
*preader = new LibSVMBinaryReader<float>();
}
extern "C" DATAREADER_API void GetReaderD(IDataReader** preader)
{
*preader = new LibSVMBinaryReader<double>();
}
extern "C" DATAREADER_API void GetReaderF(IDataReader<float>** preader)
{
GetReader(preader);
}
extern "C" DATAREADER_API void GetReaderD(IDataReader<double>** preader)
{
GetReader(preader);
}
} } }
}}}

Просмотреть файл

@ -226,12 +226,9 @@ private:
};
template <class ElemType>
class LibSVMBinaryReader : public IDataReader<ElemType>
class LibSVMBinaryReader : public IDataReader
{
public:
using LabelType = typename IDataReader<ElemType>::LabelType;
using LabelIdType = typename IDataReader<ElemType>::LabelIdType;
virtual void Init(const ConfigParameters& config) override
{
InitFromConfig(config);

Просмотреть файл

@ -19,7 +19,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
typedef ReaderPtr (*ReaderFactory)(const ConfigParameters& parameters);
template <class ElemType>
class ReaderShim : public IDataReader<ElemType>
class ReaderShim : public IDataReader
{
public:
explicit ReaderShim(ReaderFactory factory);

Просмотреть файл

@ -12,18 +12,13 @@
namespace Microsoft { namespace MSR { namespace CNTK {
template <class ElemType>
void DATAREADER_API GetReader(IDataReader<ElemType>** preader)
extern "C" DATAREADER_API void GetReaderF(IDataReader** preader)
{
*preader = new SparsePCReader<ElemType>();
*preader = new SparsePCReader<float>();
}
extern "C" DATAREADER_API void GetReaderD(IDataReader** preader)
{
*preader = new SparsePCReader<double>();
}
extern "C" DATAREADER_API void GetReaderF(IDataReader<float>** preader)
{
GetReader(preader);
}
extern "C" DATAREADER_API void GetReaderD(IDataReader<double>** preader)
{
GetReader(preader);
}
} } }
}}}

Просмотреть файл

@ -317,7 +317,7 @@ bool SparsePCReader<ElemType>::DataEnd() { return true; }
// GetLabelMapping - Gets the label mapping from integer index to label type
// returns - a map from numeric datatype to native label type
template <class ElemType>
const std::map<typename IDataReader<ElemType>::LabelIdType, typename IDataReader<ElemType>::LabelType>& SparsePCReader<ElemType>::GetLabelMapping(const std::wstring& /*sectionName*/)
const std::map<IDataReader::LabelIdType, IDataReader::LabelType>& SparsePCReader<ElemType>::GetLabelMapping(const std::wstring& /*sectionName*/)
{
return m_mapIdToLabel;
}
@ -326,7 +326,7 @@ const std::map<typename IDataReader<ElemType>::LabelIdType, typename IDataReader
// labelMapping - mapping table from label values to IDs (must be 0-n)
// note: for tasks with labels, the mapping table must be the same between a training run and a testing run
template <class ElemType>
void SparsePCReader<ElemType>::SetLabelMapping(const std::wstring& /*sectionName*/, const std::map<typename IDataReader<ElemType>::LabelIdType, LabelType>& labelMapping)
void SparsePCReader<ElemType>::SetLabelMapping(const std::wstring& /*sectionName*/, const std::map<IDataReader::LabelIdType, LabelType>& labelMapping)
{
m_mapIdToLabel = labelMapping;
m_mapLabelToId.clear();

Просмотреть файл

@ -21,12 +21,8 @@
namespace Microsoft { namespace MSR { namespace CNTK {
template <class ElemType>
class SparsePCReader : public IDataReader<ElemType>
class SparsePCReader : public IDataReader
{
public:
using LabelType = typename IDataReader<ElemType>::LabelType;
using LabelIdType = typename IDataReader<ElemType>::LabelIdType;
private:
ConfigParameters m_readerConfig;
std::wstring m_file;
size_t m_featureCount;

Просмотреть файл

@ -12,18 +12,13 @@
namespace Microsoft { namespace MSR { namespace CNTK {
template <class ElemType>
void DATAREADER_API GetReader(IDataReader<ElemType>** preader)
extern "C" DATAREADER_API void GetReaderF(IDataReader** preader)
{
*preader = new UCIFastReader<ElemType>();
*preader = new UCIFastReader<float>();
}
extern "C" DATAREADER_API void GetReaderD(IDataReader** preader)
{
*preader = new UCIFastReader<double>();
}
extern "C" DATAREADER_API void GetReaderF(IDataReader<float>** preader)
{
GetReader(preader);
}
extern "C" DATAREADER_API void GetReaderD(IDataReader<double>** preader)
{
GetReader(preader);
}
} } }
}}}

Просмотреть файл

@ -500,11 +500,11 @@ void UCIFastReader<ElemType>::InitCache(const ConfigParameters& readerConfig)
// mmodify the config so the reader types look correct
config["readerType"] = config("writerType");
config["file"] = filesList;
m_cachingReader = new DataReader<ElemType>(config);
m_cachingReader = new DataReader(config);
}
else
{
m_cachingWriter = new DataWriter<ElemType>(readerConfig);
m_cachingWriter = new DataWriter(readerConfig);
// now get the section names for map and category types
std::map<std::wstring, SectionType, nocase_compare> sections;
@ -1015,7 +1015,7 @@ bool UCIFastReader<ElemType>::GetMinibatchImpl(StreamMinibatchInputs& matrices)
// GetLabelMapping - Gets the label mapping from integer index to label type
// returns - a map from numeric datatype to native label type
template <class ElemType>
const std::map<typename IDataReader<ElemType>::LabelIdType, typename IDataReader<ElemType>::LabelType>& UCIFastReader<ElemType>::GetLabelMapping(const std::wstring& sectionName)
const std::map<IDataReader::LabelIdType, IDataReader::LabelType>& UCIFastReader<ElemType>::GetLabelMapping(const std::wstring& sectionName)
{
if (m_cachingReader)
{
@ -1028,7 +1028,7 @@ const std::map<typename IDataReader<ElemType>::LabelIdType, typename IDataReader
// labelMapping - mapping table from label values to IDs (must be 0-n)
// note: for tasks with labels, the mapping table must be the same between a training run and a testing run
template <class ElemType>
void UCIFastReader<ElemType>::SetLabelMapping(const std::wstring& /*sectionName*/, const std::map<typename IDataReader<ElemType>::LabelIdType, LabelType>& labelMapping)
void UCIFastReader<ElemType>::SetLabelMapping(const std::wstring& /*sectionName*/, const std::map<LabelIdType, LabelType>& labelMapping)
{
if (m_cachingReader)
{

Просмотреть файл

@ -36,15 +36,8 @@ enum LabelKind
};
template <class ElemType>
class UCIFastReader : public IDataReader<ElemType>
class UCIFastReader : public IDataReader
{
public:
using LabelType = typename IDataReader<ElemType>::LabelType;
using LabelIdType = typename IDataReader<ElemType>::LabelIdType;
using IDataReader<ElemType>::mRequestedNumParallelSequences;
// typedef std::string LabelType;
// typedef unsigned LabelIdType;
private:
shared_ptr<UCIParser<ElemType, LabelType>> m_parser;
size_t m_mbSize; // size of minibatch requested
LabelIdType m_labelIdMax; // maximum label ID we have encountered so far
@ -102,8 +95,8 @@ private:
unique_ptr<CUDAPageLockedMemAllocator> m_cudaAllocator;
// caching support
DataReader<ElemType>* m_cachingReader;
DataWriter<ElemType>* m_cachingWriter;
DataReader* m_cachingReader;
DataWriter* m_cachingWriter;
ConfigParameters m_readerConfig;
void InitCache(const ConfigParameters& config);
void InitCache(const ScriptableObjects::IConfigRecord& config);

Просмотреть файл

@ -24,7 +24,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// Note: This will go away with the redesigned reader interface.
// TODO: callers of this often do ComputationNetwork::BumpEvalTimeStamp(featureNodes) and also for labels; we should eliminate the need for this.
template <class ElemType>
static bool GetMinibatchIntoNetwork(IDataReader<ElemType>& trainSetDataReader,
static bool GetMinibatchIntoNetwork(IDataReader& trainSetDataReader,
ComputationNetworkPtr net,
ComputationNodeBasePtr criterionNode,
bool useDistributedMBReading,
@ -345,7 +345,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
}
}
size_t GetMinibatchIntoCache(IDataReader<ElemType>& trainSetDataReader,
size_t GetMinibatchIntoCache(IDataReader& trainSetDataReader,
ComputationNetwork& net,
StreamMinibatchInputs& inputMatrices,
size_t requestedSubminibatches)

Просмотреть файл

@ -36,8 +36,8 @@ template SGD<double>::SGD(const ScriptableObjects::IConfigRecord&);
template <class ElemType>
void SGD<ElemType>::Train(function<ComputationNetworkPtr(DEVICEID_TYPE)> createNetworkFn, DEVICEID_TYPE deviceId,
IDataReader<ElemType>* trainSetDataReader,
IDataReader<ElemType>* validationSetDataReader,
IDataReader* trainSetDataReader,
IDataReader* validationSetDataReader,
const bool makeMode)
{
// determine which epoch to start with, including recovering a checkpoint if any and 'makeMode' enabled
@ -80,8 +80,8 @@ void SGD<ElemType>::Train(function<ComputationNetworkPtr(DEVICEID_TYPE)> createN
template <class ElemType>
void SGD<ElemType>::Adapt(wstring origModelFileName, wstring refNodeName,
IDataReader<ElemType>* trainSetDataReader,
IDataReader<ElemType>* validationSetDataReader,
IDataReader* trainSetDataReader,
IDataReader* validationSetDataReader,
const DEVICEID_TYPE deviceId, const bool makeMode)
{
int startEpoch = DetermineStartEpoch(makeMode);
@ -139,8 +139,8 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
bool networkLoadedFromCheckpoint,
ComputationNetworkPtr refNet,
ComputationNodeBasePtr refNode,
IDataReader<ElemType>* trainSetDataReader,
IDataReader<ElemType>* validationSetDataReader)
IDataReader* trainSetDataReader,
IDataReader* validationSetDataReader)
{
auto& featureNodes = net->FeatureNodes();
auto& labelNodes = net->LabelNodes();
@ -188,7 +188,7 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
{
auto& nodes = (pass == 0) ? featureNodes : labelNodes;
for (const auto & node : nodes)
(*inputMatrices).AddInputMatrix(node->NodeName(), dynamic_pointer_cast<ComputationNode<ElemType>>(node)->ValuePtr());
(*inputMatrices).AddInputMatrix(node->NodeName(), node->ValuePtr());
}
// get hmm file for sequence training
@ -699,7 +699,7 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
const ComputationNodeBasePtr& refNode,
const int epochNumber,
const size_t epochSize,
IDataReader<ElemType>* trainSetDataReader,
IDataReader* trainSetDataReader,
const double learnRatePerSample,
size_t tunedMBSize,
const std::vector<ComputationNodeBasePtr>& featureNodes,
@ -843,7 +843,7 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
// get minibatch
// TODO: is it guaranteed that the GPU is already completed at this point, is it safe to overwrite the buffers?
size_t actualMBSize = 0;
bool wasDataRead = DataReaderHelpers::GetMinibatchIntoNetwork(*trainSetDataReader, net, criterionNodes[0],
bool wasDataRead = DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*trainSetDataReader, net, criterionNodes[0],
useDistributedMBReading, useParallelTrain, *inputMatrices, actualMBSize);
if (!wasDataRead && (!useDistributedMBReading || noMoreSamplesToProcess)) // in case of distributed reading, we do a few more loops until all ranks have completed
break; // end of epoch
@ -1277,7 +1277,7 @@ std::vector<ComputationNodeBasePtr>& SGD<ElemType>::GetEvalCriterionNodes(Comput
// Returns true if precomputation was executed.
template <class ElemType>
bool SGD<ElemType>::PreCompute(ComputationNetworkPtr net,
IDataReader<ElemType>* trainSetDataReader,
IDataReader* trainSetDataReader,
std::vector<ComputationNodeBasePtr>& featureNodes,
std::vector<ComputationNodeBasePtr>& labelNodes,
StreamMinibatchInputs* inputMatrices)
@ -1312,7 +1312,7 @@ bool SGD<ElemType>::PreCompute(ComputationNetworkPtr net,
const size_t numIterationsBeforePrintingProgress = 100;
size_t numItersSinceLastPrintOfProgress = 0;
size_t actualMBSizeDummy;
while (DataReaderHelpers::GetMinibatchIntoNetwork(*trainSetDataReader, net, nullptr, false, false, *inputMatrices, actualMBSizeDummy))
while (DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*trainSetDataReader, net, nullptr, false, false, *inputMatrices, actualMBSizeDummy))
{
// TODO: move these into GetMinibatchIntoNetwork() --but those are passed around; necessary? Can't we get them from 'net'?
ComputationNetwork::BumpEvalTimeStamp(featureNodes);
@ -1347,7 +1347,7 @@ double SGD<ElemType>::SearchForBestLearnRate(ComputationNetworkPtr net,
ComputationNetworkPtr refNet,
const ComputationNodeBasePtr& refNode, const int epochNumber,
const double curLearnRate,
IDataReader<ElemType>* trainSetDataReader,
IDataReader* trainSetDataReader,
const std::vector<ComputationNodeBasePtr>& featureNodes,
const std::vector<ComputationNodeBasePtr>& labelNodes,
const std::vector<ComputationNodeBasePtr>& criterionNodes,
@ -1514,7 +1514,7 @@ size_t SGD<ElemType>::AdaptiveMinibatchSizing(ComputationNetworkPtr net,
const ComputationNodeBasePtr& refNode,
const int epochNumber,
const size_t numFramesToUseInSearch,
IDataReader<ElemType>* trainSetDataReader,
IDataReader* trainSetDataReader,
const double learnRatePerSample,
const size_t initialMinibatchSize,
const std::vector<ComputationNodeBasePtr>& featureNodes,
@ -1620,7 +1620,7 @@ size_t SGD<ElemType>::SearchForBestMinibatchSize(ComputationNetworkPtr net,
const ComputationNodeBasePtr& refNode,
const int epochNumber,
const size_t numFramesToUseInSearch,
IDataReader<ElemType>* trainSetDataReader,
IDataReader* trainSetDataReader,
const double learnRatePerSample,
const std::vector<ComputationNodeBasePtr>& featureNodes,
const std::vector<ComputationNodeBasePtr>& labelNodes,
@ -1716,7 +1716,7 @@ template <class ElemType>
void SGD<ElemType>::TrainOneMiniEpochAndReloadModel(ComputationNetworkPtr net,
ComputationNetworkPtr refNet,
const ComputationNodeBasePtr& refNode, const int epochNumber,
const size_t epochSize, IDataReader<ElemType>* trainSetDataReader,
const size_t epochSize, IDataReader* trainSetDataReader,
const double learnRatePerSample,
const size_t minibatchSize,
const std::vector<ComputationNodeBasePtr>& featureNodes,
@ -1773,7 +1773,7 @@ void SGD<ElemType>::TrainOneMiniEpochAndReloadModel(ComputationNetworkPtr net,
// TODO: move the two-forward-pass support out of the reader.
template <class ElemType>
void SGD<ElemType>::AttemptUtteranceDerivativeFeatures(ComputationNetworkPtr net,
IDataReader<ElemType>* trainSetDataReader,
IDataReader* trainSetDataReader,
const std::vector<ComputationNodeBasePtr>& featureNodes,
StreamMinibatchInputs* inputMatrices)
{

Просмотреть файл

@ -296,12 +296,12 @@ public:
}
void Train(function<ComputationNetworkPtr(DEVICEID_TYPE)> createNetworkFn, DEVICEID_TYPE deviceId,
IDataReader<ElemType>* trainSetDataReader,
IDataReader<ElemType>* validationSetDataReader,
IDataReader* trainSetDataReader,
IDataReader* validationSetDataReader,
const bool makeMode = true);
void Adapt(wstring origModelFileName, wstring refNodeName,
IDataReader<ElemType>* trainSetDataReader,
IDataReader<ElemType>* validationSetDataReader,
IDataReader* trainSetDataReader,
IDataReader* validationSetDataReader,
const DEVICEID_TYPE deviceID, const bool makeMode = true);
protected:
@ -313,14 +313,14 @@ protected:
bool networkLoadedFromCheckpoint,
ComputationNetworkPtr refNet,
ComputationNodeBasePtr refNode,
IDataReader<ElemType>* trainSetDataReader,
IDataReader<ElemType>* validationSetDataReader);
IDataReader* trainSetDataReader,
IDataReader* validationSetDataReader);
protected:
// return true if precomputation is executed.
bool PreCompute(ComputationNetworkPtr net,
IDataReader<ElemType>* trainSetDataReader,
IDataReader* trainSetDataReader,
std::vector<ComputationNodeBasePtr>& featureNodes,
std::vector<ComputationNodeBasePtr>& labelNodes,
StreamMinibatchInputs* inputMatrices);
@ -330,7 +330,7 @@ protected:
ComputationNetworkPtr refNet,
const ComputationNodeBasePtr& refNode, const int epochNumber,
const double curLearnRate,
IDataReader<ElemType>* trainSetDataReader,
IDataReader* trainSetDataReader,
const std::vector<ComputationNodeBasePtr>& featureNodes,
const std::vector<ComputationNodeBasePtr>& labelNodes,
const std::vector<ComputationNodeBasePtr>& criterionNodes,
@ -344,7 +344,7 @@ protected:
void TrainOneMiniEpochAndReloadModel(ComputationNetworkPtr net,
ComputationNetworkPtr refNet,
const ComputationNodeBasePtr& refNode, const int epochNumber,
const size_t epochSize, IDataReader<ElemType>* trainSetDataReader,
const size_t epochSize, IDataReader* trainSetDataReader,
const double learnRatePerSample,
const size_t minibatchSize,
const std::vector<ComputationNodeBasePtr>& featureNodes,
@ -364,7 +364,7 @@ protected:
const ComputationNodeBasePtr& refNode,
const int epochNumber,
const size_t numFramesToUseInSearch,
IDataReader<ElemType>* trainSetDataReader,
IDataReader* trainSetDataReader,
const double learnRatePerSample,
const size_t initialMinibatchSize,
const std::vector<ComputationNodeBasePtr>& featureNodes,
@ -383,7 +383,7 @@ protected:
const ComputationNodeBasePtr& refNode,
const int epochNumber,
const size_t numFramesToUseInSearch,
IDataReader<ElemType>* trainSetDataReader,
IDataReader* trainSetDataReader,
const double learnRatePerSample,
const std::vector<ComputationNodeBasePtr>& featureNodes,
const std::vector<ComputationNodeBasePtr>& labelNodes,
@ -400,7 +400,7 @@ protected:
// processing more utterances at the same time. Only used in Kaldi2Reader.
// TODO: move the two-forward-pass support out of the reader.
void AttemptUtteranceDerivativeFeatures(ComputationNetworkPtr net,
IDataReader<ElemType>* trainSetDataReader,
IDataReader* trainSetDataReader,
const std::vector<ComputationNodeBasePtr>& featureNodes,
StreamMinibatchInputs* inputMatrices);
@ -409,7 +409,7 @@ protected:
const ComputationNodeBasePtr& refNode,
const int epochNumber,
const size_t epochSize,
IDataReader<ElemType>* trainSetDataReader,
IDataReader* trainSetDataReader,
const double learnRatePerSample,
size_t tunedMBSize,
const std::vector<ComputationNodeBasePtr>& featureNodes,

Просмотреть файл

@ -30,7 +30,7 @@ public:
}
// returns evaluation node values per sample determined by evalNodeNames (which can include both training and eval criterion nodes)
vector<double> Evaluate(IDataReader<ElemType>* dataReader, const vector<wstring>& evalNodeNames, const size_t mbSize, const size_t testSize = requestDataSize)
vector<double> Evaluate(IDataReader* dataReader, const vector<wstring>& evalNodeNames, const size_t mbSize, const size_t testSize = requestDataSize)
{
// determine nodes to evaluate
std::vector<ComputationNodeBasePtr> evalNodes;
@ -77,9 +77,9 @@ public:
StreamMinibatchInputs inputMatrices;
for (auto& node : featureNodes)
inputMatrices.AddInputMatrix(node->NodeName(), node->As<ComputationNode<ElemType>>()->ValuePtr());
inputMatrices.AddInputMatrix(node->NodeName(), node->ValuePtr());
for (auto& node : labelNodes)
inputMatrices.AddInputMatrix(node->NodeName(), node->As<ComputationNode<ElemType>>()->ValuePtr());
inputMatrices.AddInputMatrix(node->NodeName(), node->ValuePtr());
// evaluate through minibatches
size_t totalEpochSamples = 0;
@ -95,7 +95,7 @@ public:
dataReader->StartMinibatchLoop(mbSize, 0, testSize);
m_net->StartEvaluateMinibatchLoop(evalNodes);
while (DataReaderHelpers::GetMinibatchIntoNetwork(*dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize))
while (DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize))
{
ComputationNetwork::BumpEvalTimeStamp(featureNodes);
ComputationNetwork::BumpEvalTimeStamp(labelNodes);

Просмотреть файл

@ -74,7 +74,7 @@ private:
{
StreamMinibatchInputs inputMatrices;
for (auto& node : inputNodes)
inputMatrices.AddInputMatrix(node->NodeName(), node->As<ComputationNode<ElemType>>()->ValuePtr());
inputMatrices.AddInputMatrix(node->NodeName(), node->ValuePtr());
return inputMatrices;
}
@ -84,7 +84,7 @@ public:
{
}
void WriteOutput(IDataReader<ElemType>& dataReader, size_t mbSize, IDataWriter<ElemType>& dataWriter, const std::vector<std::wstring>& outputNodeNames, size_t numOutputSamples = requestDataSize, bool doUnitTest = false)
void WriteOutput(IDataReader& dataReader, size_t mbSize, IDataWriter& dataWriter, const std::vector<std::wstring>& outputNodeNames, size_t numOutputSamples = requestDataSize, bool doUnitTest = false)
{
std::vector<ComputationNodeBasePtr> outputNodes = DetermineOutputNodes(outputNodeNames);
std::vector<ComputationNodeBasePtr> inputNodes = DetermineInputNodes(outputNodes);
@ -104,7 +104,7 @@ public:
std::map<std::wstring, void*, nocase_compare> outputMatrices;
size_t actualMBSize;
while (DataReaderHelpers::GetMinibatchIntoNetwork(dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize))
while (DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize))
{
ComputationNetwork::BumpEvalTimeStamp(inputNodes);
@ -174,7 +174,7 @@ public:
};
// TODO: Remove code dup with above function by creating a fake Writer object and then calling the other function.
void WriteOutput(IDataReader<ElemType>& dataReader, size_t mbSize, std::wstring outputPath, const std::vector<std::wstring>& outputNodeNames, const WriteFormattingOptions & formattingOptions, size_t numOutputSamples = requestDataSize)
void WriteOutput(IDataReader& dataReader, size_t mbSize, std::wstring outputPath, const std::vector<std::wstring>& outputNodeNames, const WriteFormattingOptions & formattingOptions, size_t numOutputSamples = requestDataSize)
{
std::vector<ComputationNodeBasePtr> outputNodes = DetermineOutputNodes(outputNodeNames);
std::vector<ComputationNodeBasePtr> inputNodes = DetermineInputNodes(outputNodes);
@ -221,7 +221,7 @@ public:
std::string valueFormatString = "%" + formattingOptions.precisionFormat + formatChar; // format string used in fprintf() for formatting the values
size_t actualMBSize;
while (DataReaderHelpers::GetMinibatchIntoNetwork(dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize))
while (DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize))
{
ComputationNetwork::BumpEvalTimeStamp(inputNodes);

Просмотреть файл

@ -48,7 +48,7 @@ void DoCommand(const ConfigParameters& configRoot)
Eval<ElemType> eval(config);
DataReader<ElemType>* dataReader = new DataReader<ElemType>(readerConfig);
auto dataReader = make_shared<DataReader>(readerConfig);
eval.LoadModel(modelPath);
dataReader->StartMinibatchLoop(mbSize, 0, epochSize);
eval.StartEvaluateMinibatchLoop(outputName);

Просмотреть файл

@ -146,7 +146,7 @@ struct ReaderFixture
template <class ElemType>
void HelperWriteReaderContentToFile(
ofstream& outputFile,
DataReader<ElemType>& dataReader,
DataReader& dataReader,
StreamMinibatchInputs& map,
size_t epochs,
size_t mbSize,
@ -225,7 +225,7 @@ struct ReaderFixture
const ConfigParameters simpleDemoConfig = config(testSectionName);
const ConfigParameters readerConfig = simpleDemoConfig(readerSectionName);
DataReader<ElemType> dataReader(readerConfig);
DataReader dataReader(readerConfig);
StreamMinibatchInputs map;
std::vector<shared_ptr<Matrix<ElemType>>> features;
@ -250,7 +250,7 @@ struct ReaderFixture
ofstream outputFile(testDataFilePath, ios::out);
// Perform the data reading
HelperWriteReaderContentToFile(outputFile, dataReader, map, epochs, mbSize, epochSize, numFeatureFiles, numLabelFiles, subsetNum, numSubsets);
HelperWriteReaderContentToFile<ElemType>(outputFile, dataReader, map, epochs, mbSize, epochSize, numFeatureFiles, numLabelFiles, subsetNum, numSubsets);
outputFile.close();
@ -286,8 +286,9 @@ struct ReaderFixture
const ConfigParameters simpleDemoConfig = config(testSectionName);
const ConfigParameters readerConfig = simpleDemoConfig(readerSectionName);
BOOST_CHECK_THROW(DataReader<ElemType> dataReader(readerConfig), ExceptionType);
BOOST_CHECK_THROW(DataReader dataReader(readerConfig), ExceptionType);
}
};
}
} } }
}}}