removed template parameter ElemType from (I)DataReader and (I)DataWriter
This commit is contained in:
Родитель
5bdc86e5de
Коммит
e54b352822
|
@ -298,12 +298,12 @@ status open
|
|||
|
||||
\begin_layout Plain Layout
|
||||
|
||||
extern "C" DATAREADER_API void GetReaderF(IDataReader<float>** preader);
|
||||
extern "C" DATAREADER_API void GetReaderF(IDataReader** preader);
|
||||
\end_layout
|
||||
|
||||
\begin_layout Plain Layout
|
||||
|
||||
extern "C" DATAREADER_API void GetReaderD(IDataReader<double>** preader);
|
||||
extern "C" DATAREADER_API void GetReaderD(IDataReader** preader);
|
||||
\end_layout
|
||||
|
||||
\end_inset
|
||||
|
@ -319,12 +319,12 @@ status open
|
|||
|
||||
\begin_layout Plain Layout
|
||||
|
||||
extern "C" DATAWRITER_API void GetWriterF(IDataWriter<float>** pwriter);
|
||||
extern "C" DATAWRITER_API void GetWriterF(IDataWriter** pwriter);
|
||||
\end_layout
|
||||
|
||||
\begin_layout Plain Layout
|
||||
|
||||
extern "C" DATAWRITER_API void GetWriterD(IDataWriter<double>** pwriter);
|
||||
extern "C" DATAWRITER_API void GetWriterD(IDataWriter** pwriter);
|
||||
\end_layout
|
||||
|
||||
\end_inset
|
||||
|
@ -440,14 +440,13 @@ ochSamples=requestDataSize) = 0;
|
|||
|
||||
\begin_layout Plain Layout
|
||||
|
||||
virtual const std::map<typename LabelIdType, typename LabelType>& GetLabelMapp
|
||||
virtual const std::map<LabelIdType, LabelType>& GetLabelMapp
|
||||
ing(const std::wstring& sectionName) = 0;
|
||||
\end_layout
|
||||
|
||||
\begin_layout Plain Layout
|
||||
|
||||
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<t
|
||||
ypename LabelIdType, typename LabelType>& labelMapping) = 0;
|
||||
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, LabelType>& labelMapping) = 0;
|
||||
\end_layout
|
||||
|
||||
\begin_layout Plain Layout
|
||||
|
@ -1097,8 +1096,7 @@ public:
|
|||
|
||||
\begin_layout Plain Layout
|
||||
|
||||
virtual void SaveMapping(std::wstring saveId, const std::map<typename
|
||||
LabelIdType, typename LabelType>& labelMapping) = 0;
|
||||
virtual void SaveMapping(std::wstring saveId, const std::map<LabelIdType, LabelType>& labelMapping) = 0;
|
||||
\end_layout
|
||||
|
||||
\begin_layout Plain Layout
|
||||
|
|
|
@ -321,7 +321,7 @@ test = [
|
|||
action = "eval"
|
||||
|
||||
# correspond to the number of words/characteres to train in a minibatch
|
||||
minibatchSize = 8192 # choose as large as memory allows for maximum concurrency
|
||||
minibatchSize = 8192 # choose as large as memory allows for maximum GPU concurrency
|
||||
# need to be small since models are updated for each minibatch
|
||||
traceLevel = 1
|
||||
epochSize = 0
|
||||
|
|
5
Makefile
5
Makefile
|
@ -307,8 +307,8 @@ $(BINARY_READER): $(BINARYREADER_OBJ) | $(CNTKMATH_LIB)
|
|||
########################################
|
||||
|
||||
HTKMLFREADER_SRC =\
|
||||
$(SOURCEDIR)/Readers/HTKMLFReader/DataReader.cpp \
|
||||
$(SOURCEDIR)/Readers/HTKMLFReader/DataWriter.cpp \
|
||||
$(SOURCEDIR)/Readers/HTKMLFReader/Exports.cpp \
|
||||
$(SOURCEDIR)/Readers/HTKMLFReader/DataWriterLocal.cpp \
|
||||
$(SOURCEDIR)/Readers/HTKMLFReader/HTKMLFReader.cpp \
|
||||
$(SOURCEDIR)/Readers/HTKMLFReader/HTKMLFWriter.cpp \
|
||||
|
||||
|
@ -348,6 +348,7 @@ $(LMSEQUENCEREADER): $(LMSEQUENCEREADER_OBJ) | $(CNTKMATH_LIB)
|
|||
|
||||
LUSEQUENCEREADER_SRC =\
|
||||
$(SOURCEDIR)/Readers/LUSequenceReader/Exports.cpp \
|
||||
$(SOURCEDIR)/Readers/LUSequenceReader/DataWriterLocal.cpp \
|
||||
$(SOURCEDIR)/Readers/LUSequenceReader/LUSequenceParser.cpp \
|
||||
$(SOURCEDIR)/Readers/LUSequenceReader/LUSequenceReader.cpp \
|
||||
|
||||
|
|
|
@ -43,7 +43,7 @@ using namespace Microsoft::MSR::CNTK;
|
|||
// ===========================================================================
|
||||
|
||||
template <typename ElemType>
|
||||
static void DoEvalBase(const ConfigParameters& config, IDataReader<ElemType>& reader)
|
||||
static void DoEvalBase(const ConfigParameters& config, IDataReader& reader)
|
||||
{
|
||||
DEVICEID_TYPE deviceId = DeviceFromConfig(config);
|
||||
ConfigArray minibatchSize = config(L"minibatchSize", "40960");
|
||||
|
@ -78,9 +78,9 @@ void DoEval(const ConfigParameters& config)
|
|||
ConfigParameters readerConfig(config(L"reader"));
|
||||
readerConfig.Insert("traceLevel", config(L"traceLevel", "0"));
|
||||
|
||||
DataReader<ElemType> testDataReader(readerConfig);
|
||||
DataReader testDataReader(readerConfig);
|
||||
|
||||
DoEvalBase(config, testDataReader);
|
||||
DoEvalBase<ElemType>(config, testDataReader);
|
||||
}
|
||||
|
||||
template void DoEval<double>(const ConfigParameters& config);
|
||||
|
@ -125,7 +125,7 @@ void DoCrossValidate(const ConfigParameters& config)
|
|||
std::vector<std::vector<double>> cvErrorResults;
|
||||
std::vector<std::wstring> cvModels;
|
||||
|
||||
DataReader<ElemType> cvDataReader(readerConfig);
|
||||
DataReader cvDataReader(readerConfig);
|
||||
|
||||
bool finalModelEvaluated = false;
|
||||
for (size_t i = cvInterval[0]; i <= cvInterval[2]; i += cvInterval[1])
|
||||
|
@ -206,7 +206,7 @@ void DoWriteOutput(const ConfigParameters& config)
|
|||
readerConfig.Insert("traceLevel", config(L"traceLevel", "0"));
|
||||
readerConfig.Insert("randomize", "None"); // we don't want randomization when output results
|
||||
|
||||
DataReader<ElemType> testDataReader(readerConfig);
|
||||
DataReader testDataReader(readerConfig);
|
||||
|
||||
DEVICEID_TYPE deviceId = DeviceFromConfig(config);
|
||||
ConfigArray minibatchSize = config(L"minibatchSize", "2048");
|
||||
|
@ -244,7 +244,7 @@ void DoWriteOutput(const ConfigParameters& config)
|
|||
{
|
||||
ConfigParameters writerConfig(config(L"writer"));
|
||||
bool bWriterUnittest = writerConfig(L"unittest", "false");
|
||||
DataWriter<ElemType> testDataWriter(writerConfig);
|
||||
DataWriter testDataWriter(writerConfig);
|
||||
writer.WriteOutput(testDataReader, mbSize[0], testDataWriter, outputNodeNamesVector, epochSize, bWriterUnittest);
|
||||
}
|
||||
else if (config.Exists("outputPath"))
|
||||
|
|
|
@ -85,8 +85,7 @@ void DoCreateLabelMap(const ConfigParameters& config)
|
|||
}
|
||||
fprintf(stderr, "CreateLabelMap: Creating the mapping file '%s' \n", labelMappingFile.c_str());
|
||||
|
||||
DataReader<ElemType> dataReader(readerConfig);
|
||||
|
||||
DataReader dataReader(readerConfig);
|
||||
dataReader.StartMinibatchLoop(minibatchSize, 0, requestDataSize);
|
||||
int count = 0;
|
||||
while (dataReader.GetMinibatch(matrices))
|
||||
|
|
|
@ -155,11 +155,11 @@ void DoTrain(const ConfigRecordType& config)
|
|||
RuntimeError("No network builder found in the config file. NDLNetworkBuilder or SimpleNetworkBuilde must be specified");
|
||||
}
|
||||
|
||||
auto dataReader = CreateObject<DataReader<ElemType>>(config, L"reader");
|
||||
auto dataReader = CreateObject<DataReader>(config, L"reader");
|
||||
|
||||
shared_ptr<DataReader<ElemType>> cvDataReader;
|
||||
shared_ptr<DataReader> cvDataReader;
|
||||
if (config.Exists(L"cvReader"))
|
||||
cvDataReader = CreateObject<DataReader<ElemType>>(config, L"cvReader");
|
||||
cvDataReader = CreateObject<DataReader>(config, L"cvReader");
|
||||
|
||||
shared_ptr<SGD<ElemType>> optimizer;
|
||||
if (config.Exists(L"optimizer"))
|
||||
|
@ -225,15 +225,15 @@ void DoAdapt(const ConfigParameters& config)
|
|||
ConfigParameters readerConfig(config(L"reader"));
|
||||
readerConfig.Insert("traceLevel", config(L"traceLevel", "0"));
|
||||
|
||||
DataReader<ElemType>* dataReader = new DataReader<ElemType>(readerConfig);
|
||||
auto dataReader = make_shared<DataReader>(readerConfig);
|
||||
|
||||
DataReader<ElemType>* cvDataReader = nullptr;
|
||||
shared_ptr<DataReader> cvDataReader;
|
||||
ConfigParameters cvReaderConfig(config(L"cvReader", L""));
|
||||
|
||||
if (cvReaderConfig.size() != 0)
|
||||
{
|
||||
cvReaderConfig.Insert("traceLevel", config(L"traceLevel", "0"));
|
||||
cvDataReader = new DataReader<ElemType>(cvReaderConfig);
|
||||
cvDataReader = make_shared<DataReader>(cvReaderConfig);
|
||||
}
|
||||
|
||||
wstring origModelFileName = config(L"origModelFileName", L"");
|
||||
|
@ -241,10 +241,7 @@ void DoAdapt(const ConfigParameters& config)
|
|||
|
||||
SGD<ElemType> sgd(configSGD);
|
||||
|
||||
sgd.Adapt(origModelFileName, refNodeName, dataReader, cvDataReader, deviceId, makeMode);
|
||||
|
||||
delete dataReader;
|
||||
delete cvDataReader;
|
||||
sgd.Adapt(origModelFileName, refNodeName, dataReader.get(), cvDataReader.get(), deviceId, makeMode);
|
||||
}
|
||||
|
||||
template void DoAdapt<float>(const ConfigParameters& config);
|
||||
|
|
|
@ -612,25 +612,17 @@ int wmainOldCNTKConfig(int argc, wchar_t* argv[]) // called from wmain which is
|
|||
std::string type = config(L"precision", "float");
|
||||
// accept old precision key for backward compatibility
|
||||
if (config.Exists("type"))
|
||||
{
|
||||
type = config(L"type", "float");
|
||||
}
|
||||
InvalidArgument("CNTK: Use of 'type' parameter is deprecated, it is called 'precision' now.");
|
||||
|
||||
fprintf(stderr, "\nPrecision = \"%s\"\n", type.c_str());
|
||||
if (type == "float")
|
||||
{
|
||||
DoCommands<float>(config);
|
||||
}
|
||||
else if (type == "double")
|
||||
{
|
||||
DoCommands<double>(config);
|
||||
}
|
||||
else
|
||||
{
|
||||
RuntimeError("CNTK: Invalid precision string: \"%s\", must be \"float\" or \"double\"", type.c_str());
|
||||
}
|
||||
|
||||
// still here , write a DoneFile if necessary
|
||||
// if completed then write a DoneFile if requested
|
||||
if (!DoneFile.empty())
|
||||
{
|
||||
FILE* fp = fopenOrDie(DoneFile.c_str(), L"w");
|
||||
|
|
|
@ -107,7 +107,7 @@ void TestReader(const ConfigParameters& configBase)
|
|||
epochSize = requestDataSize;
|
||||
}
|
||||
|
||||
DataReader<ElemType> dataReader(readerConfig);
|
||||
DataReader dataReader(readerConfig);
|
||||
|
||||
// get names of features and labels
|
||||
std::vector<std::wstring> featureNames;
|
||||
|
@ -171,7 +171,7 @@ void TestSequenceReader(const ConfigParameters& configBase)
|
|||
std::vector<std::wstring> labelNames;
|
||||
GetFileConfigNames(readerConfig, featureNames, labelNames);
|
||||
|
||||
DataReader<ElemType> dataReader(readerConfig);
|
||||
DataReader dataReader(readerConfig);
|
||||
|
||||
// get names of features and labels
|
||||
std::vector<std::wstring> files;
|
||||
|
|
|
@ -15,32 +15,24 @@
|
|||
#include "DataReader.h"
|
||||
#include "Config.h"
|
||||
#include "ScriptableObjects.h"
|
||||
#include <string>
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
template <class ElemType>
|
||||
std::string GetReaderName(ElemType)
|
||||
static const char* GetReaderName(const string& precision)
|
||||
{
|
||||
return std::string();
|
||||
}
|
||||
template <>
|
||||
std::string GetReaderName(float)
|
||||
{
|
||||
std::string name = "GetReaderF";
|
||||
return name;
|
||||
}
|
||||
template <>
|
||||
std::string GetReaderName(double)
|
||||
{
|
||||
std::string name = "GetReaderD";
|
||||
return name;
|
||||
if (precision == "float")
|
||||
return "GetReaderF";
|
||||
else if (precision == "double")
|
||||
return "GetReaderD";
|
||||
else
|
||||
InvalidArgument("DataReader: The 'precision' parameter must be 'float' or 'double'.");
|
||||
}
|
||||
|
||||
template <class ElemType>
|
||||
template <class ConfigRecordType>
|
||||
void DataReader<ElemType>::InitFromConfig(const ConfigRecordType& /*config*/)
|
||||
void DataReader::InitFromConfig(const ConfigRecordType& /*config*/)
|
||||
{
|
||||
RuntimeError("Init shouldn't be called, use constructor");
|
||||
// not implemented, calls the underlying class instead
|
||||
|
@ -48,8 +40,7 @@ void DataReader<ElemType>::InitFromConfig(const ConfigRecordType& /*config*/)
|
|||
|
||||
// Destroy - cleanup and remove this class
|
||||
// NOTE: this destroys the object, and it can't be used past this point
|
||||
template <class ElemType>
|
||||
void DataReader<ElemType>::Destroy()
|
||||
void DataReader::Destroy()
|
||||
{
|
||||
// newer code that explicitly place multiple streams for inputs
|
||||
foreach_index (i, m_ioNames) // inputNames should map to node names
|
||||
|
@ -60,14 +51,15 @@ void DataReader<ElemType>::Destroy()
|
|||
|
||||
// DataReader Constructor
|
||||
// options - [in] string of options (i.e. "-windowsize:11 -addenergy") data reader specific
|
||||
template <class ElemType>
|
||||
template <class ConfigRecordType>
|
||||
DataReader<ElemType>::DataReader(const ConfigRecordType& config)
|
||||
DataReader::DataReader(const ConfigRecordType& config)
|
||||
{
|
||||
typedef void (*GetReaderProc)(IDataReader<ElemType>** preader);
|
||||
typedef void (*GetReaderProc)(IDataReader** preader);
|
||||
|
||||
assert(m_dataReaders.empty());
|
||||
|
||||
string precision = config(L"precision", "float");
|
||||
|
||||
bool hasMultipleReaders = config.Exists(L"readers");
|
||||
if (hasMultipleReaders)
|
||||
{
|
||||
|
@ -77,7 +69,7 @@ DataReader<ElemType>::DataReader(const ConfigRecordType& config)
|
|||
{
|
||||
const ConfigRecordType& thisIO = config(ioName);
|
||||
// get the name for the reader we want to use, default to UCIFastReader
|
||||
GetReaderProc getReaderProc = (GetReaderProc) Plugin::Load(thisIO(L"readerType", L"UCIFastReader"), GetReaderName((ElemType) 0));
|
||||
GetReaderProc getReaderProc = (GetReaderProc) Plugin::Load(thisIO(L"readerType", L"UCIFastReader"), GetReaderName(precision));
|
||||
m_ioNames.push_back(ioName);
|
||||
assert(getReaderProc != nullptr);
|
||||
getReaderProc(&m_dataReaders[ioName]); // instantiates the reader with the default constructor (no config processed at this point)
|
||||
|
@ -88,7 +80,7 @@ DataReader<ElemType>::DataReader(const ConfigRecordType& config)
|
|||
wstring ioName = L"ioName";
|
||||
// backward support to use only one type of data reader
|
||||
// get the name for the reader we want to use, default to UCIFastReader
|
||||
GetReaderProc getReaderProc = (GetReaderProc) Plugin::Load(config(L"readerType", L"UCIFastReader"), GetReaderName((ElemType) 0));
|
||||
GetReaderProc getReaderProc = (GetReaderProc)Plugin::Load(config(L"readerType", L"UCIFastReader"), GetReaderName(precision));
|
||||
m_ioNames.push_back(ioName);
|
||||
assert(getReaderProc != nullptr);
|
||||
getReaderProc(&m_dataReaders[ioName]);
|
||||
|
@ -109,14 +101,11 @@ DataReader<ElemType>::DataReader(const ConfigRecordType& config)
|
|||
}
|
||||
}
|
||||
|
||||
template DataReader<float>::DataReader(const ConfigParameters&);
|
||||
template DataReader<double>::DataReader(const ConfigParameters&);
|
||||
template DataReader<float>::DataReader(const ScriptableObjects::IConfigRecord&);
|
||||
template DataReader<double>::DataReader(const ScriptableObjects::IConfigRecord&);
|
||||
template DataReader::DataReader(const ConfigParameters&);
|
||||
template DataReader::DataReader(const ScriptableObjects::IConfigRecord&);
|
||||
|
||||
// destructor - cleanup temp files, etc.
|
||||
template <class ElemType>
|
||||
DataReader<ElemType>::~DataReader()
|
||||
DataReader::~DataReader()
|
||||
{
|
||||
// free up resources
|
||||
for (size_t i = 0; i < m_ioNames.size(); i++)
|
||||
|
@ -127,16 +116,14 @@ DataReader<ElemType>::~DataReader()
|
|||
// mbSize - [in] size of the minibatch (number of frames, etc.)
|
||||
// epoch - [in] epoch number for this loop
|
||||
// requestedEpochSamples - [in] number of samples to randomize, defaults to requestDataSize which uses the number of samples there are in the dataset
|
||||
template <class ElemType>
|
||||
void DataReader<ElemType>::StartMinibatchLoop(size_t mbSize, size_t epoch, size_t requestedEpochSamples)
|
||||
void DataReader::StartMinibatchLoop(size_t mbSize, size_t epoch, size_t requestedEpochSamples)
|
||||
{
|
||||
for (size_t i = 0; i < m_ioNames.size(); i++)
|
||||
m_dataReaders[m_ioNames[i]]->StartMinibatchLoop(mbSize, epoch, requestedEpochSamples);
|
||||
}
|
||||
|
||||
//SupportsDistributedMBRead - Tells if the reader supports distributed minibatch reading for parallel training
|
||||
template <class ElemType>
|
||||
bool DataReader<ElemType>::SupportsDistributedMBRead() const
|
||||
bool DataReader::SupportsDistributedMBRead() const
|
||||
{
|
||||
bool supportsDistributedMBRead = true;
|
||||
for (size_t i = 0; i < m_ioNames.size(); i++)
|
||||
|
@ -156,8 +143,7 @@ bool DataReader<ElemType>::SupportsDistributedMBRead() const
|
|||
// subsetNum - [in] the subset number of the current node in a group of parallel training nodes
|
||||
// numSubsets - [in] total number of nodes participating in the parallel training
|
||||
// requestedEpochSamples - [in] number of samples to randomize, defaults to requestDataSize which uses the number of samples there are in the dataset
|
||||
template <class ElemType>
|
||||
void DataReader<ElemType>::StartDistributedMinibatchLoop(size_t mbSize, size_t epoch, size_t subsetNum, size_t numSubsets, size_t requestedEpochSamples /* = requestDataSize*/)
|
||||
void DataReader::StartDistributedMinibatchLoop(size_t mbSize, size_t epoch, size_t subsetNum, size_t numSubsets, size_t requestedEpochSamples /* = requestDataSize*/)
|
||||
{
|
||||
for (size_t i = 0; i < m_ioNames.size(); i++)
|
||||
{
|
||||
|
@ -169,8 +155,7 @@ void DataReader<ElemType>::StartDistributedMinibatchLoop(size_t mbSize, size_t e
|
|||
// matrices - [in] a map with named matrix types (i.e. 'features', 'labels') mapped to the corresponding matrix,
|
||||
// [out] each matrix resized if necessary containing data.
|
||||
// returns - true if there are more minibatches, false if no more minibatchs remain
|
||||
template <class ElemType>
|
||||
bool DataReader<ElemType>::GetMinibatch(StreamMinibatchInputs& matrices)
|
||||
bool DataReader::GetMinibatch(StreamMinibatchInputs& matrices)
|
||||
{
|
||||
/**
|
||||
each reader reads data with number of columns as nbr_utterances_per_minibatch * mbSize
|
||||
|
@ -193,7 +178,7 @@ bool DataReader<ElemType>::GetMinibatch(StreamMinibatchInputs& matrices)
|
|||
if (nbr == 0)
|
||||
nbr = thisNbr;
|
||||
else if (thisNbr != nbr)
|
||||
LogicError("DataReader<ElemType>::GetMinibatch: The specified number of utterances per minibatch is not consistent to the actual number of utterances per minibatch");
|
||||
LogicError("DataReader::GetMinibatch: The specified number of utterances per minibatch is not consistent to the actual number of utterances per minibatch");
|
||||
}
|
||||
return bRet;
|
||||
}
|
||||
|
@ -203,38 +188,31 @@ bool DataReader<ElemType>::GetMinibatch(StreamMinibatchInputs& matrices)
|
|||
// uids - lables stored in size_t vector instead of ElemType matrix
|
||||
// boundary - phone boundaries
|
||||
// returns - true if there are more minibatches, false if no more minibatchs remain
|
||||
template <class ElemType>
|
||||
bool DataReader<ElemType>::GetMinibatch4SE(std::vector<shared_ptr<const msra::dbn::latticepair>>& latticeinput, vector<size_t>& uids, vector<size_t>& boundaries, vector<size_t>& extrauttmap)
|
||||
bool DataReader::GetMinibatch4SE(std::vector<shared_ptr<const msra::dbn::latticepair>>& latticeinput, vector<size_t>& uids, vector<size_t>& boundaries, vector<size_t>& extrauttmap)
|
||||
{
|
||||
bool bRet = true;
|
||||
for (size_t i = 0; i < m_ioNames.size(); i++)
|
||||
{
|
||||
bRet &= m_dataReaders[m_ioNames[i]]->GetMinibatch4SE(latticeinput, uids, boundaries, extrauttmap);
|
||||
}
|
||||
return bRet;
|
||||
}
|
||||
|
||||
// GetHmmData - Get the HMM definition for SE training
|
||||
// hmm - HMM definition
|
||||
// returns - true if succeed
|
||||
template <class ElemType>
|
||||
bool DataReader<ElemType>::GetHmmData(msra::asr::simplesenonehmm* hmm)
|
||||
bool DataReader::GetHmmData(msra::asr::simplesenonehmm* hmm)
|
||||
{
|
||||
bool bRet = true;
|
||||
for (size_t i = 0; i < m_ioNames.size(); i++)
|
||||
{
|
||||
bRet &= m_dataReaders[m_ioNames[i]]->GetHmmData(hmm);
|
||||
}
|
||||
return bRet;
|
||||
}
|
||||
|
||||
template <class ElemType>
|
||||
size_t DataReader<ElemType>::GetNumParallelSequences()
|
||||
size_t DataReader::GetNumParallelSequences()
|
||||
{
|
||||
size_t nNbr = 0;
|
||||
for (size_t i = 0; i < m_ioNames.size(); i++)
|
||||
{
|
||||
IDataReader<ElemType>* ptr = m_dataReaders[m_ioNames[i]];
|
||||
IDataReader* ptr = m_dataReaders[m_ioNames[i]];
|
||||
if (nNbr == 0)
|
||||
nNbr = ptr->GetNumParallelSequences();
|
||||
else if (nNbr != ptr->GetNumParallelSequences())
|
||||
|
@ -243,37 +221,13 @@ size_t DataReader<ElemType>::GetNumParallelSequences()
|
|||
return nNbr;
|
||||
}
|
||||
|
||||
#if 0
|
||||
template <class ElemType>
|
||||
bool DataReader<ElemType>::RequireSentenceSeg() const
|
||||
{
|
||||
bool ans = false;
|
||||
for (size_t i = 0; i < m_ioNames.size(); i++)
|
||||
ans = ans || m_dataReaders.find(m_ioNames[i])->second->RequireSentenceSeg(); // can't say m_dataReaders[] since that is non-const...
|
||||
return ans;
|
||||
}
|
||||
#endif
|
||||
|
||||
template <class ElemType>
|
||||
void DataReader<ElemType>::InitProposals(StreamMinibatchInputs* matrices)
|
||||
void DataReader::InitProposals(StreamMinibatchInputs* matrices)
|
||||
{
|
||||
for (size_t i = 0; i < m_ioNames.size(); i++)
|
||||
m_dataReaders[m_ioNames[i]]->InitProposals(matrices);
|
||||
}
|
||||
|
||||
#if 0
|
||||
template <class ElemType>
|
||||
int DataReader<ElemType>::GetSentenceEndIdFromOutputLabel()
|
||||
{
|
||||
int iRet = -1;
|
||||
for (size_t i = 0; i < m_ioNames.size(); i++)
|
||||
iRet = m_dataReaders[m_ioNames[i]]->GetSentenceEndIdFromOutputLabel();
|
||||
return iRet;
|
||||
}
|
||||
#endif
|
||||
|
||||
template <class ElemType>
|
||||
bool DataReader<ElemType>::GetProposalObs(StreamMinibatchInputs* matrices, const size_t tidx, vector<size_t>& history)
|
||||
bool DataReader::GetProposalObs(StreamMinibatchInputs* matrices, const size_t tidx, vector<size_t>& history)
|
||||
{
|
||||
bool bRet = true;
|
||||
for (size_t i = 0; i < m_ioNames.size(); i++)
|
||||
|
@ -281,23 +235,20 @@ bool DataReader<ElemType>::GetProposalObs(StreamMinibatchInputs* matrices, const
|
|||
return bRet;
|
||||
}
|
||||
|
||||
template <class ElemType>
|
||||
void DataReader<ElemType>::CopyMBLayoutTo(MBLayoutPtr pMBLayout)
|
||||
void DataReader::CopyMBLayoutTo(MBLayoutPtr pMBLayout)
|
||||
{
|
||||
// BUGBUG: This copies all data reader's layout info on top of each other, keeping only the last one; likely not what was intended.
|
||||
for (size_t i = 0; i < m_ioNames.size(); i++)
|
||||
m_dataReaders[m_ioNames[i]]->CopyMBLayoutTo(pMBLayout);
|
||||
}
|
||||
|
||||
template <class ElemType>
|
||||
void DataReader<ElemType>::SetRandomSeed(int seed)
|
||||
void DataReader::SetRandomSeed(int seed)
|
||||
{
|
||||
for (size_t i = 0; i < m_ioNames.size(); i++)
|
||||
m_dataReaders[m_ioNames[i]]->SetRandomSeed(seed);
|
||||
}
|
||||
|
||||
template <class ElemType>
|
||||
bool DataReader<ElemType>::GetMinibatchCopy(
|
||||
bool DataReader::GetMinibatchCopy(
|
||||
std::vector<std::vector<std::pair<wstring, size_t>>>& uttInfo,
|
||||
StreamMinibatchInputs& matrices,
|
||||
MBLayoutPtr pMBLayout)
|
||||
|
@ -308,8 +259,7 @@ bool DataReader<ElemType>::GetMinibatchCopy(
|
|||
return ans;
|
||||
}
|
||||
|
||||
template <class ElemType>
|
||||
bool DataReader<ElemType>::SetNetOutput(
|
||||
bool DataReader::SetNetOutput(
|
||||
const std::vector<std::vector<std::pair<wstring, size_t>>>& uttInfo,
|
||||
const MatrixBase& outputs,
|
||||
const MBLayoutPtr pMBLayout)
|
||||
|
@ -322,8 +272,7 @@ bool DataReader<ElemType>::SetNetOutput(
|
|||
|
||||
// GetLabelMapping - Gets the label mapping from integer index to label type
|
||||
// returns - a map from numeric datatype to native label type
|
||||
template <class ElemType>
|
||||
const std::map<typename DataReader<ElemType>::LabelIdType, typename DataReader<ElemType>::LabelType>& DataReader<ElemType>::GetLabelMapping(const std::wstring&)
|
||||
const std::map<typename DataReader::LabelIdType, typename DataReader::LabelType>& DataReader::GetLabelMapping(const std::wstring&)
|
||||
{
|
||||
NOT_IMPLEMENTED;
|
||||
}
|
||||
|
@ -331,8 +280,7 @@ const std::map<typename DataReader<ElemType>::LabelIdType, typename DataReader<E
|
|||
// SetLabelMapping - Sets the label mapping from integer index to label
|
||||
// labelMapping - mapping table from label values to IDs (must be 0-n)
|
||||
// note: for tasks with labels, the mapping table must be the same between a training run and a testing run
|
||||
template <class ElemType>
|
||||
void DataReader<ElemType>::SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, LabelType>& labelMapping)
|
||||
void DataReader::SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, LabelType>& labelMapping)
|
||||
{
|
||||
for (size_t i = 0; i < m_ioNames.size(); i++)
|
||||
m_dataReaders[m_ioNames[i]]->SetLabelMapping(sectionName, labelMapping);
|
||||
|
@ -346,8 +294,7 @@ void DataReader<ElemType>::SetLabelMapping(const std::wstring& sectionName, cons
|
|||
// [out] size of buffer filled with data
|
||||
// recordStart - record to start reading from, defaults to zero (start of data)
|
||||
// returns: true if data remains to be read, false if the end of data was reached
|
||||
template <class ElemType>
|
||||
bool DataReader<ElemType>::GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart)
|
||||
bool DataReader::GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart)
|
||||
{
|
||||
bool bRet = true;
|
||||
for (size_t i = 0; i < m_ioNames.size(); i++)
|
||||
|
@ -355,8 +302,7 @@ bool DataReader<ElemType>::GetData(const std::wstring& sectionName, size_t numRe
|
|||
return bRet;
|
||||
}
|
||||
|
||||
template <class ElemType>
|
||||
bool DataReader<ElemType>::DataEnd()
|
||||
bool DataReader::DataEnd()
|
||||
{
|
||||
bool bRet = true;
|
||||
for (size_t i = 0; i < m_ioNames.size(); i++)
|
||||
|
@ -364,10 +310,7 @@ bool DataReader<ElemType>::DataEnd()
|
|||
return bRet;
|
||||
}
|
||||
|
||||
//The explicit instantiation
|
||||
template class DataReader<double>;
|
||||
template class DataReader<float>;
|
||||
|
||||
// register SGD<> with the ScriptableObject system
|
||||
ScriptableObjects::ConfigurableRuntimeTypeRegister::AddFloatDouble<DataReader<float>, DataReader<double>> registerDataReaderPlugin(L"DataReaderPlugin");
|
||||
} } }
|
||||
ScriptableObjects::ConfigurableRuntimeTypeRegister::Add<DataReader> registerDataReaderPlugin(L"DataReaderPlugin");
|
||||
|
||||
}}}
|
||||
|
|
|
@ -11,29 +11,18 @@
|
|||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
template <class ElemType>
|
||||
std::string GetWriterName(ElemType)
|
||||
static const char* GetWriterName(const string& precision)
|
||||
{
|
||||
std::string empty;
|
||||
return empty;
|
||||
if (precision == "float")
|
||||
return "GetWriterF";
|
||||
else if (precision == "double")
|
||||
return "GetWriterD";
|
||||
else
|
||||
InvalidArgument("DataWriter: The 'precision' parameter must be 'float' or 'double'.");
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string GetWriterName(float)
|
||||
{
|
||||
std::string name = "GetWriterF";
|
||||
return name;
|
||||
}
|
||||
template <>
|
||||
std::string GetWriterName(double)
|
||||
{
|
||||
std::string name = "GetWriterD";
|
||||
return name;
|
||||
}
|
||||
|
||||
template <class ElemType>
|
||||
template <class ConfigRecordType>
|
||||
void DataWriter<ElemType>::InitFromConfig(const ConfigRecordType& /*config*/)
|
||||
void DataWriter::InitFromConfig(const ConfigRecordType& /*config*/)
|
||||
{
|
||||
RuntimeError("Init shouldn't be called, use constructor");
|
||||
// not implemented, calls the underlying class instead
|
||||
|
@ -41,22 +30,17 @@ void DataWriter<ElemType>::InitFromConfig(const ConfigRecordType& /*config*/)
|
|||
|
||||
// Destroy - cleanup and remove this class
|
||||
// NOTE: this destroys the object, and it can't be used past this point
|
||||
template <class ElemType>
|
||||
void DataWriter<ElemType>::Destroy()
|
||||
void DataWriter::Destroy()
|
||||
{
|
||||
m_dataWriter->Destroy();
|
||||
}
|
||||
|
||||
// DataWriter Constructor
|
||||
// config - [in] configuration data for the data writer
|
||||
template <class ElemType>
|
||||
template <class ConfigRecordType>
|
||||
DataWriter<ElemType>::DataWriter(const ConfigRecordType& config)
|
||||
DataWriter::DataWriter(const ConfigRecordType& config)
|
||||
{
|
||||
typedef void (*GetWriterProc)(IDataWriter<ElemType>** pwriter);
|
||||
|
||||
// initialize just in case
|
||||
m_dataWriter = NULL;
|
||||
typedef void (*GetWriterProc)(IDataWriter** pwriter);
|
||||
|
||||
// get the name for the writer we want to use, default to BinaryWriter (which is in BinaryReader.dll)
|
||||
// TODO: This seems like a find-replace operation?
|
||||
|
@ -82,21 +66,20 @@ DataWriter<ElemType>::DataWriter(const ConfigRecordType& config)
|
|||
writerType = L"KaldiReader";
|
||||
}
|
||||
|
||||
ElemType elemType = ElemType();
|
||||
GetWriterProc getWriterProc = (GetWriterProc) Plugin::Load(writerType, GetWriterName(elemType));
|
||||
string precision = config(L"precision", "float");
|
||||
|
||||
GetWriterProc getWriterProc = (GetWriterProc)Plugin::Load(writerType, GetWriterName(precision));
|
||||
m_dataWriter = NULL;
|
||||
getWriterProc(&m_dataWriter);
|
||||
|
||||
m_dataWriter->Init(config);
|
||||
}
|
||||
|
||||
template DataWriter<float>::DataWriter(const ConfigParameters&);
|
||||
template DataWriter<double>::DataWriter(const ConfigParameters&);
|
||||
template DataWriter<float>::DataWriter(const ScriptableObjects::IConfigRecord&);
|
||||
template DataWriter<double>::DataWriter(const ScriptableObjects::IConfigRecord&);
|
||||
template DataWriter::DataWriter(const ConfigParameters&);
|
||||
template DataWriter::DataWriter(const ScriptableObjects::IConfigRecord&);
|
||||
|
||||
// destructor - cleanup temp files, etc.
|
||||
template <class ElemType>
|
||||
DataWriter<ElemType>::~DataWriter()
|
||||
DataWriter::~DataWriter()
|
||||
{
|
||||
// free up resources
|
||||
if (m_dataWriter)
|
||||
|
@ -105,8 +88,7 @@ DataWriter<ElemType>::~DataWriter()
|
|||
|
||||
// GetSections - Get the sections of the file
|
||||
// sections - a map of section name to section. Data sepcifications from config file will be used to determine where and how to save data
|
||||
template <class ElemType>
|
||||
void DataWriter<ElemType>::GetSections(std::map<std::wstring, SectionType, nocase_compare>& sections)
|
||||
void DataWriter::GetSections(std::map<std::wstring, SectionType, nocase_compare>& sections)
|
||||
{
|
||||
m_dataWriter->GetSections(sections);
|
||||
}
|
||||
|
@ -117,8 +99,7 @@ void DataWriter<ElemType>::GetSections(std::map<std::wstring, SectionType, nocas
|
|||
// numRecords - number of records we are saving, can be zero if not applicable
|
||||
// datasetSize - Size of the dataset
|
||||
// byteVariableSized - for variable sized data, size of current block to be written, zero when not used, or ignored if not variable sized data
|
||||
template <class ElemType>
|
||||
bool DataWriter<ElemType>::SaveData(size_t recordStart, const std::map<std::wstring, void*, nocase_compare>& matrices, size_t numRecords, size_t datasetSize, size_t byteVariableSized)
|
||||
bool DataWriter::SaveData(size_t recordStart, const std::map<std::wstring, void*, nocase_compare>& matrices, size_t numRecords, size_t datasetSize, size_t byteVariableSized)
|
||||
{
|
||||
return m_dataWriter->SaveData(recordStart, matrices, numRecords, datasetSize, byteVariableSized);
|
||||
}
|
||||
|
@ -126,13 +107,9 @@ bool DataWriter<ElemType>::SaveData(size_t recordStart, const std::map<std::wstr
|
|||
// SaveMapping - save a map into the file
|
||||
// saveId - name of the section to save into (section:subsection format)
|
||||
// labelMapping - map we are saving to the file
|
||||
template <class ElemType>
|
||||
void DataWriter<ElemType>::SaveMapping(std::wstring saveId, const std::map<LabelIdType, LabelType>& labelMapping)
|
||||
void DataWriter::SaveMapping(std::wstring saveId, const std::map<LabelIdType, LabelType>& labelMapping)
|
||||
{
|
||||
m_dataWriter->SaveMapping(saveId, labelMapping);
|
||||
}
|
||||
|
||||
//The explicit instantiation
|
||||
template class DataWriter<double>;
|
||||
template class DataWriter<float>;
|
||||
} } }
|
||||
}}}
|
||||
|
|
|
@ -85,14 +85,13 @@ public:
|
|||
|
||||
// Data Reader interface
|
||||
// implemented by DataReader and underlying classes
|
||||
// TODO: Remove <ElemType>. Only one method to go: SetNetOutput().
|
||||
template <class ElemType>
|
||||
class DATAREADER_API IDataReader
|
||||
{
|
||||
public:
|
||||
typedef std::string LabelType; // surface form of an input token
|
||||
typedef unsigned int LabelIdType; // input token mapped to an integer --TODO: why not size_t? Does this save space?
|
||||
|
||||
// BUGBUG: We should not have data members in an interace!
|
||||
unsigned m_seed;
|
||||
size_t mRequestedNumParallelSequences; // number of desired parallel sequences in each minibatch
|
||||
|
||||
|
@ -199,27 +198,20 @@ public:
|
|||
return false;
|
||||
}
|
||||
};
|
||||
typedef std::shared_ptr<IDataReader> IDataReaderPtr;
|
||||
|
||||
// GetReader - get a reader type from the DLL
|
||||
// since we have 2 reader types based on template parameters, exposes 2 exports
|
||||
// could be done directly the templated name, but that requires mangled C++ names
|
||||
template <class ElemType>
|
||||
void DATAREADER_API GetReader(IDataReader<ElemType>** preader);
|
||||
extern "C" DATAREADER_API void GetReaderF(IDataReader<float>** preader);
|
||||
extern "C" DATAREADER_API void GetReaderD(IDataReader<double>** preader);
|
||||
// GetReaderX() - get a reader type from the DLL
|
||||
// The F version gets the 'float' version, and D gets 'double'.
|
||||
extern "C" DATAREADER_API void GetReaderF(IDataReader** preader);
|
||||
extern "C" DATAREADER_API void GetReaderD(IDataReader** preader);
|
||||
|
||||
// Data Reader class
|
||||
// interface for clients of the Data Reader
|
||||
// mirrors the IDataReader interface, except the Init method is private (use the constructor)
|
||||
template <class ElemType>
|
||||
class DataReader : public IDataReader<ElemType>, protected Plugin, public ScriptableObjects::Object
|
||||
class DataReader : public IDataReader, protected Plugin, public ScriptableObjects::Object
|
||||
{
|
||||
typedef typename IDataReader<ElemType>::LabelType LabelType;
|
||||
typedef typename IDataReader<ElemType>::LabelIdType LabelIdType;
|
||||
|
||||
private:
|
||||
vector<wstring> m_ioNames; // TODO: why are these needed, why not loop over m_dataReaders?
|
||||
map<wstring, IDataReader<ElemType>*> m_dataReaders; // readers
|
||||
map<wstring, IDataReader*> m_dataReaders; // readers
|
||||
|
||||
// Init - Reader Initialize for multiple data sets
|
||||
// config - [in] configuration parameters for the datareader
|
||||
|
|
|
@ -46,7 +46,6 @@ enum SectionType
|
|||
|
||||
// Data Writer interface
|
||||
// implemented by some DataWriters
|
||||
template <class ElemType>
|
||||
class DATAWRITER_API IDataWriter
|
||||
{
|
||||
public:
|
||||
|
@ -64,26 +63,19 @@ public:
|
|||
virtual void SaveMapping(std::wstring saveId, const std::map<LabelIdType, LabelType>& labelMapping) = 0;
|
||||
virtual bool SupportMultiUtterances() const = 0;
|
||||
};
|
||||
typedef std::shared_ptr<IDataWriter> IDataWriterPtr;
|
||||
|
||||
// GetWriter - get a reader type from the DLL
|
||||
// since we have 2 writerr types based on template parameters, exposes 2 exports
|
||||
// could be done directly the templated name, but that requires mangled C++ names
|
||||
template <class ElemType>
|
||||
void DATAWRITER_API GetWriter(IDataWriter<ElemType>** pwriter);
|
||||
extern "C" DATAWRITER_API void GetWriterF(IDataWriter<float>** pwriter);
|
||||
extern "C" DATAWRITER_API void GetWriterD(IDataWriter<double>** pwriter);
|
||||
// The F version gets the 'float' version, and D gets 'double'.
|
||||
extern "C" DATAWRITER_API void GetWriterF(IDataWriter** pwriter);
|
||||
extern "C" DATAWRITER_API void GetWriterD(IDataWriter** pwriter);
|
||||
|
||||
// Data Writer class
|
||||
// interface for clients of the Data Writer
|
||||
// mirrors the IDataWriter interface, except the Init method is private (use the constructor)
|
||||
template <class ElemType>
|
||||
class DataWriter : public IDataWriter<ElemType>, protected Plugin
|
||||
class DataWriter : public IDataWriter, protected Plugin
|
||||
{
|
||||
typedef typename IDataWriter<ElemType>::LabelType LabelType;
|
||||
typedef typename IDataWriter<ElemType>::LabelIdType LabelIdType;
|
||||
|
||||
private:
|
||||
IDataWriter<ElemType>* m_dataWriter; // writer
|
||||
IDataWriter* m_dataWriter; // writer
|
||||
|
||||
// Init - Writer Initialize for multiple data sets
|
||||
// config - [in] configuration parameters for the datawriter
|
||||
|
|
|
@ -574,6 +574,7 @@ public:
|
|||
|
||||
// helper to access to element(0,0) without having to type-cast
|
||||
virtual double Get00Element() const = 0;
|
||||
virtual MatrixBasePtr ValuePtr() const = 0; // for use in readers that pass the agnostic object around
|
||||
|
||||
// TODO: two sets of functions, choose one
|
||||
const std::wstring& NodeName() const { return m_nodeName; }
|
||||
|
@ -1094,7 +1095,8 @@ public:
|
|||
const Matrix<ElemType>& Value() const { return *m_value; }
|
||||
Matrix<ElemType>& Value() { return *m_value; }
|
||||
|
||||
std::shared_ptr<Matrix<ElemType>> ValuePtr() { return m_value; } // readers want this as a shared_ptr straight
|
||||
MatrixBasePtr ValuePtr() const override final { return m_value; } // readers want this as a shared_ptr straight
|
||||
// Note: We cannot return a const& since returning m_value as a MatrixBasePtr is a type cast that generates a temporary. Interesting.
|
||||
|
||||
const Matrix<ElemType>& Gradient() const { return *m_gradient; }
|
||||
Matrix<ElemType>& Gradient() { return *m_gradient; }
|
||||
|
@ -1677,6 +1679,7 @@ public:
|
|||
virtual void CopyTo(ComputationNodeBasePtr node, const std::wstring& newName, const CopyNodeFlags flags) const override { NOT_IMPLEMENTED; }
|
||||
virtual ComputationNodeBasePtr Duplicate(const std::wstring& newName, const CopyNodeFlags flags) override { NOT_IMPLEMENTED; }
|
||||
virtual double Get00Element() const override { NOT_IMPLEMENTED; }
|
||||
virtual MatrixBasePtr ValuePtr() const override { NOT_IMPLEMENTED; }
|
||||
virtual void UpdateFunctionMBSize() override { NOT_IMPLEMENTED; }
|
||||
virtual void AttachInputs(const std::vector<ComputationNodeBasePtr>& inputs) override { NOT_IMPLEMENTED; }
|
||||
virtual void PrintSelf(bool) const override { NOT_IMPLEMENTED; }
|
||||
|
|
|
@ -12,12 +12,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
// Evaluation Reader class
|
||||
// interface to pass to evaluation DLL
|
||||
template <class ElemType>
|
||||
class EvalReader : public IDataReader<ElemType>
|
||||
class EvalReader : public IDataReader
|
||||
{
|
||||
typedef typename IDataReader<ElemType>::LabelType LabelType;
|
||||
typedef typename IDataReader<ElemType>::LabelIdType LabelIdType;
|
||||
|
||||
private:
|
||||
std::map<std::wstring, std::vector<ElemType>*>* m_inputs; // our input data
|
||||
std::map<std::wstring, size_t>* m_dimensions; // the number of rows for the input data
|
||||
size_t m_recordCount; // count of records in this data
|
||||
|
|
|
@ -12,16 +12,13 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
// Evaluation Writer class
|
||||
// interface to pass to evaluation DLL
|
||||
template <class ElemType>
|
||||
class EvalWriter : public IDataWriter<ElemType>
|
||||
class EvalWriter : public IDataWriter
|
||||
{
|
||||
typedef typename IDataWriter<ElemType>::LabelType LabelType;
|
||||
typedef typename IDataWriter<ElemType>::LabelIdType LabelIdType;
|
||||
|
||||
private:
|
||||
std::map<std::wstring, std::vector<ElemType>*>* m_outputs; // our output data
|
||||
std::map<std::wstring, size_t>* m_dimensions; // the number of rows for the output data
|
||||
size_t m_recordCount; // count of records in this data
|
||||
size_t m_currentRecord; // next record number to read
|
||||
|
||||
public:
|
||||
// Method to setup the data for the reader
|
||||
void SetData(std::map<std::wstring, std::vector<ElemType>*>* outputs, std::map<std::wstring, size_t>* dimensions)
|
||||
|
|
|
@ -541,12 +541,8 @@ public:
|
|||
};
|
||||
|
||||
template <class ElemType>
|
||||
class BinaryReader : public IDataReader<ElemType>
|
||||
class BinaryReader : public IDataReader
|
||||
{
|
||||
typedef typename IDataReader<ElemType>::LabelType LabelType;
|
||||
typedef typename IDataReader<ElemType>::LabelIdType LabelIdType;
|
||||
|
||||
private:
|
||||
size_t m_mbSize; // size of minibatch requested
|
||||
size_t m_mbStartSample; // starting sample # of the next minibatch
|
||||
size_t m_epochSize; // size of an epoch
|
||||
|
@ -615,12 +611,8 @@ public:
|
|||
};
|
||||
|
||||
template <class ElemType>
|
||||
class BinaryWriter : public IDataWriter<ElemType>
|
||||
class BinaryWriter : public IDataWriter
|
||||
{
|
||||
typedef typename IDataWriter<ElemType>::LabelType LabelType;
|
||||
typedef typename IDataWriter<ElemType>::LabelIdType LabelIdType;
|
||||
|
||||
private:
|
||||
int m_traceLevel; // trace level to output the
|
||||
size_t m_recordCurrent;
|
||||
size_t m_recordMax;
|
||||
|
|
|
@ -14,33 +14,22 @@
|
|||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
template <class ElemType>
|
||||
void DATAREADER_API GetReader(IDataReader<ElemType>** preader)
|
||||
extern "C" DATAREADER_API void GetReaderF(IDataReader** preader)
|
||||
{
|
||||
*preader = new BinaryReader<ElemType>();
|
||||
*preader = new BinaryReader<float>();
|
||||
}
|
||||
extern "C" DATAREADER_API void GetReaderD(IDataReader** preader)
|
||||
{
|
||||
*preader = new BinaryReader<double>();
|
||||
}
|
||||
|
||||
extern "C" DATAREADER_API void GetReaderF(IDataReader<float>** preader)
|
||||
extern "C" DATAWRITER_API void GetWriterF(IDataWriter** pwriter)
|
||||
{
|
||||
GetReader(preader);
|
||||
*pwriter = new BinaryWriter<float>();
|
||||
}
|
||||
extern "C" DATAREADER_API void GetReaderD(IDataReader<double>** preader)
|
||||
extern "C" DATAWRITER_API void GetWriterD(IDataWriter** pwriter)
|
||||
{
|
||||
GetReader(preader);
|
||||
*pwriter = new BinaryWriter<double>();
|
||||
}
|
||||
|
||||
template <class ElemType>
|
||||
void DATAWRITER_API GetWriter(IDataWriter<ElemType>** pwriter)
|
||||
{
|
||||
*pwriter = new BinaryWriter<ElemType>();
|
||||
}
|
||||
|
||||
extern "C" DATAWRITER_API void GetWriterF(IDataWriter<float>** pwriter)
|
||||
{
|
||||
GetWriter(pwriter);
|
||||
}
|
||||
extern "C" DATAWRITER_API void GetWriterD(IDataWriter<double>** pwriter)
|
||||
{
|
||||
GetWriter(pwriter);
|
||||
}
|
||||
} } }
|
||||
}}}
|
||||
|
|
|
@ -418,7 +418,7 @@ bool DSSMReader<ElemType>::GetMinibatch(StreamMinibatchInputs& matrices)
|
|||
// GetLabelMapping - Gets the label mapping from integer index to label type
|
||||
// returns - a map from numeric datatype to native label type
|
||||
template <class ElemType>
|
||||
const std::map<typename IDataReader<ElemType>::LabelIdType, typename IDataReader<ElemType>::LabelType>& DSSMReader<ElemType>::GetLabelMapping(const std::wstring& sectionName)
|
||||
const std::map<IDataReader::LabelIdType, IDataReader::LabelType>& DSSMReader<ElemType>::GetLabelMapping(const std::wstring& sectionName)
|
||||
{
|
||||
if (m_cachingReader)
|
||||
{
|
||||
|
@ -431,7 +431,7 @@ const std::map<typename IDataReader<ElemType>::LabelIdType, typename IDataReader
|
|||
// labelMapping - mapping table from label values to IDs (must be 0-n)
|
||||
// note: for tasks with labels, the mapping table must be the same between a training run and a testing run
|
||||
template <class ElemType>
|
||||
void DSSMReader<ElemType>::SetLabelMapping(const std::wstring& /*sectionName*/, const std::map<typename IDataReader<ElemType>::LabelIdType, typename LabelType>& labelMapping)
|
||||
void DSSMReader<ElemType>::SetLabelMapping(const std::wstring& /*sectionName*/, const std::map<LabelIdType, LabelType>& labelMapping)
|
||||
{
|
||||
if (m_cachingReader)
|
||||
{
|
||||
|
|
|
@ -64,7 +64,7 @@ public:
|
|||
};
|
||||
|
||||
template <class ElemType>
|
||||
class DSSMReader : public IDataReader<ElemType>
|
||||
class DSSMReader : public IDataReader
|
||||
{
|
||||
// public:
|
||||
// typedef std::string LabelType;
|
||||
|
@ -119,8 +119,8 @@ private:
|
|||
std::map<LabelType, LabelIdType> m_mapLabelToId;
|
||||
|
||||
// caching support
|
||||
DataReader<ElemType>* m_cachingReader;
|
||||
DataWriter<ElemType>* m_cachingWriter;
|
||||
DataReader* m_cachingReader;
|
||||
DataWriter* m_cachingWriter;
|
||||
ConfigParameters m_readerConfig;
|
||||
|
||||
size_t RandomizeSweep(size_t epochSample);
|
||||
|
@ -172,7 +172,7 @@ public:
|
|||
}
|
||||
|
||||
virtual const std::map<LabelIdType, LabelType>& GetLabelMapping(const std::wstring& sectionName);
|
||||
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, typename LabelType>& labelMapping);
|
||||
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, LabelType>& labelMapping);
|
||||
virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart = 0);
|
||||
|
||||
virtual bool DataEnd();
|
||||
|
|
|
@ -12,18 +12,13 @@
|
|||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
template <class ElemType>
|
||||
void DATAREADER_API GetReader(IDataReader<ElemType>** preader)
|
||||
extern "C" DATAREADER_API void GetReaderF(IDataReader** preader)
|
||||
{
|
||||
*preader = new DSSMReader<ElemType>();
|
||||
*preader = new DSSMReader<float>();
|
||||
}
|
||||
extern "C" DATAREADER_API void GetReaderD(IDataReader** preader)
|
||||
{
|
||||
*preader = new DSSMReader<double>();
|
||||
}
|
||||
|
||||
extern "C" DATAREADER_API void GetReaderF(IDataReader<float>** preader)
|
||||
{
|
||||
GetReader(preader);
|
||||
}
|
||||
extern "C" DATAREADER_API void GetReaderD(IDataReader<double>** preader)
|
||||
{
|
||||
GetReader(preader);
|
||||
}
|
||||
} } }
|
||||
}}}
|
||||
|
|
|
@ -4,6 +4,8 @@
|
|||
//
|
||||
// DataWriter.cpp : Defines the exported functions for the DLL application.
|
||||
//
|
||||
// TODO: This is similar but not identical to Common/DataWriter.cpp. Why is this DataWriter different? Can it be reconciled?
|
||||
//
|
||||
|
||||
#include "stdafx.h"
|
||||
#include "Basics.h"
|
||||
|
@ -16,59 +18,45 @@
|
|||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
template <class ElemType>
|
||||
void DATAWRITER_API GetWriter(IDataWriter<ElemType>** pwriter)
|
||||
{
|
||||
*pwriter = new HTKMLFWriter<ElemType>();
|
||||
}
|
||||
|
||||
extern "C" DATAWRITER_API void GetWriterF(IDataWriter<float>** pwriter)
|
||||
{
|
||||
GetWriter(pwriter);
|
||||
}
|
||||
extern "C" DATAWRITER_API void GetWriterD(IDataWriter<double>** pwriter)
|
||||
{
|
||||
GetWriter(pwriter);
|
||||
}
|
||||
|
||||
template <class ElemType>
|
||||
template <class ConfigRecordType>
|
||||
void DataWriter<ElemType>::InitFromConfig(const ConfigRecordType& writerConfig)
|
||||
void DataWriter::InitFromConfig(const ConfigRecordType& writerConfig)
|
||||
{
|
||||
m_dataWriter = new HTKMLFWriter<ElemType>();
|
||||
wstring precision = writerConfig(L"precision", L"float");
|
||||
if (precision == L"float")
|
||||
m_dataWriter = new HTKMLFWriter<float>();
|
||||
else if (precision == L"double")
|
||||
m_dataWriter = new HTKMLFWriter<double>();
|
||||
else
|
||||
InvalidArgument("DataWriter (HTKMLFWriter): The 'precision' parameter must be 'float' or 'double'.");
|
||||
|
||||
m_dataWriter->Init(writerConfig);
|
||||
}
|
||||
|
||||
// Destroy - cleanup and remove this class
|
||||
// NOTE: this destroys the object, and it can't be used past this point
|
||||
template <class ElemType>
|
||||
void DataWriter<ElemType>::Destroy()
|
||||
void DataWriter::Destroy()
|
||||
{
|
||||
delete m_dataWriter;
|
||||
m_dataWriter = NULL;
|
||||
// TODO: do we need to destroy ourselves as well?
|
||||
}
|
||||
|
||||
// DataWriter Constructor
|
||||
// config - [in] configuration data for the data writer
|
||||
template <class ElemType>
|
||||
template <class ConfigRecordType>
|
||||
DataWriter<ElemType>::DataWriter(const ConfigRecordType& config)
|
||||
DataWriter::DataWriter(const ConfigRecordType& config)
|
||||
{
|
||||
Init(config);
|
||||
}
|
||||
|
||||
// destructor - cleanup temp files, etc.
|
||||
template <class ElemType>
|
||||
DataWriter<ElemType>::~DataWriter()
|
||||
DataWriter::~DataWriter()
|
||||
{
|
||||
delete m_dataWriter;
|
||||
m_dataWriter = NULL;
|
||||
Destroy();
|
||||
}
|
||||
|
||||
// GetSections - Get the sections of the file
|
||||
// sections - a map of section name to section. Data sepcifications from config file will be used to determine where and how to save data
|
||||
template <class ElemType>
|
||||
void DataWriter<ElemType>::GetSections(std::map<std::wstring, SectionType, nocase_compare>& sections)
|
||||
void DataWriter::GetSections(std::map<std::wstring, SectionType, nocase_compare>& sections)
|
||||
{
|
||||
m_dataWriter->GetSections(sections);
|
||||
}
|
||||
|
@ -79,8 +67,7 @@ void DataWriter<ElemType>::GetSections(std::map<std::wstring, SectionType, nocas
|
|||
// numRecords - number of records we are saving, can be zero if not applicable
|
||||
// datasetSize - Size of the dataset
|
||||
// byteVariableSized - for variable sized data, size of current block to be written, zero when not used, or ignored if not variable sized data
|
||||
template <class ElemType>
|
||||
bool DataWriter<ElemType>::SaveData(size_t recordStart, const std::map<std::wstring, void*, nocase_compare>& matrices, size_t numRecords, size_t datasetSize, size_t byteVariableSized)
|
||||
bool DataWriter::SaveData(size_t recordStart, const std::map<std::wstring, void*, nocase_compare>& matrices, size_t numRecords, size_t datasetSize, size_t byteVariableSized)
|
||||
{
|
||||
return m_dataWriter->SaveData(recordStart, matrices, numRecords, datasetSize, byteVariableSized);
|
||||
}
|
||||
|
@ -88,13 +75,9 @@ bool DataWriter<ElemType>::SaveData(size_t recordStart, const std::map<std::wstr
|
|||
// SaveMapping - save a map into the file
|
||||
// saveId - name of the section to save into (section:subsection format)
|
||||
// labelMapping - map we are saving to the file
|
||||
template <class ElemType>
|
||||
void DataWriter<ElemType>::SaveMapping(std::wstring saveId, const std::map<LabelIdType, LabelType>& labelMapping)
|
||||
void DataWriter::SaveMapping(std::wstring saveId, const std::map<LabelIdType, LabelType>& labelMapping)
|
||||
{
|
||||
m_dataWriter->SaveMapping(saveId, labelMapping);
|
||||
}
|
||||
|
||||
//The explicit instantiation
|
||||
template class DataWriter<double>;
|
||||
template class DataWriter<float>;
|
||||
} } }
|
||||
}}}
|
|
@ -8,36 +8,31 @@
|
|||
#include "stdafx.h"
|
||||
#include "Basics.h"
|
||||
|
||||
#include "htkfeatio.h" // for reading HTK features
|
||||
#ifdef _WIN32
|
||||
#include "latticearchive.h" // for reading HTK phoneme lattices (MMI training)
|
||||
#endif
|
||||
#include "simplesenonehmm.h" // for MMI scoring
|
||||
#include "msra_mgram.h" // for unigram scores of ground-truth path in sequence training
|
||||
|
||||
#include "rollingwindowsource.h" // minibatch sources
|
||||
#include "chunkevalsource.h"
|
||||
#define DATAREADER_EXPORTS
|
||||
#include "DataReader.h"
|
||||
#define DATAWRITER_EXPORTS
|
||||
#include "HTKMLFReader.h"
|
||||
#include "Config.h"
|
||||
#include "HTKMLFWriter.h"
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
template <class ElemType>
|
||||
void DATAREADER_API GetReader(IDataReader<ElemType>** preader)
|
||||
extern "C" DATAREADER_API void GetReaderF(IDataReader** preader)
|
||||
{
|
||||
*preader = new HTKMLFReader<ElemType>();
|
||||
*preader = new HTKMLFReader<float>();
|
||||
}
|
||||
extern "C" DATAREADER_API void GetReaderD(IDataReader** preader)
|
||||
{
|
||||
*preader = new HTKMLFReader<double>();
|
||||
}
|
||||
|
||||
extern "C" DATAREADER_API void GetReaderF(IDataReader<float>** preader)
|
||||
extern "C" DATAWRITER_API void GetWriterF(IDataWriter** pwriter)
|
||||
{
|
||||
GetReader(preader);
|
||||
*pwriter = new HTKMLFWriter<float>();
|
||||
}
|
||||
extern "C" DATAREADER_API void GetReaderD(IDataReader<double>** preader)
|
||||
extern "C" DATAWRITER_API void GetWriterD(IDataWriter** pwriter)
|
||||
{
|
||||
GetReader(preader);
|
||||
*pwriter = new HTKMLFWriter<double>();
|
||||
}
|
||||
|
||||
#ifdef _WIN32
|
||||
// Utility function, in ConfigFile.cpp, but HTKMLFReader doesn't need that code...
|
||||
|
||||
|
@ -58,4 +53,5 @@ void Trim(std::string& str)
|
|||
str.erase(found + 1);
|
||||
}
|
||||
#endif
|
||||
} } }
|
||||
|
||||
}}}
|
|
@ -1732,7 +1732,7 @@ bool HTKMLFReader<ElemType>::ReNewBufferForMultiIO(size_t i)
|
|||
// GetLabelMapping - Gets the label mapping from integer to type in file
|
||||
// mappingTable - a map from numeric datatype to native label type stored as a string
|
||||
template <class ElemType>
|
||||
const std::map<typename IDataReader<ElemType>::LabelIdType, typename IDataReader<ElemType>::LabelType>& HTKMLFReader<ElemType>::GetLabelMapping(const std::wstring& /*sectionName*/)
|
||||
const std::map<IDataReader::LabelIdType, IDataReader::LabelType>& HTKMLFReader<ElemType>::GetLabelMapping(const std::wstring& /*sectionName*/)
|
||||
{
|
||||
return m_idToLabelMap;
|
||||
}
|
||||
|
@ -1741,7 +1741,7 @@ const std::map<typename IDataReader<ElemType>::LabelIdType, typename IDataReader
|
|||
// labelMapping - mapping table from label values to IDs (must be 0-n)
|
||||
// note: for tasks with labels, the mapping table must be the same between a training run and a testing run
|
||||
template <class ElemType>
|
||||
void HTKMLFReader<ElemType>::SetLabelMapping(const std::wstring& /*sectionName*/, const std::map<typename IDataReader<ElemType>::LabelIdType, typename IDataReader<ElemType>::LabelType>& labelMapping)
|
||||
void HTKMLFReader<ElemType>::SetLabelMapping(const std::wstring& /*sectionName*/, const std::map<IDataReader::LabelIdType, IDataReader::LabelType>& labelMapping)
|
||||
{
|
||||
m_idToLabelMap = labelMapping;
|
||||
}
|
||||
|
|
|
@ -9,10 +9,19 @@
|
|||
#include "Config.h" // for intargvector
|
||||
#include "CUDAPageLockedMemAllocator.h"
|
||||
|
||||
#include "htkfeatio.h" // for reading HTK features
|
||||
#ifdef _WIN32
|
||||
#include "latticearchive.h" // for reading HTK phoneme lattices (MMI training)
|
||||
#endif
|
||||
#include "simplesenonehmm.h" // for MMI scoring
|
||||
#include "msra_mgram.h" // for unigram scores of ground-truth path in sequence training
|
||||
#include "rollingwindowsource.h" // minibatch sources
|
||||
#include "chunkevalsource.h"
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
template <class ElemType>
|
||||
class HTKMLFReader : public IDataReader<ElemType>
|
||||
class HTKMLFReader : public IDataReader
|
||||
{
|
||||
private:
|
||||
const static size_t m_htkRandomizeAuto = 0;
|
||||
|
@ -41,8 +50,8 @@ private:
|
|||
size_t m_extraNumSeqs;
|
||||
bool m_noData;
|
||||
bool m_trainOrTest; // if false, in file writing mode
|
||||
using LabelType = typename IDataReader<ElemType>::LabelType;
|
||||
using LabelIdType = typename IDataReader<ElemType>::LabelIdType;
|
||||
using IDataReader::LabelType;
|
||||
using IDataReader::LabelIdType;
|
||||
|
||||
std::map<LabelIdType, LabelType> m_idToLabelMap;
|
||||
|
||||
|
|
|
@ -118,8 +118,8 @@
|
|||
<ClCompile Include="..\..\Common\TimerUtility.cpp">
|
||||
<PrecompiledHeader>NotUsing</PrecompiledHeader>
|
||||
</ClCompile>
|
||||
<ClCompile Include="DataReader.cpp" />
|
||||
<ClCompile Include="DataWriter.cpp" />
|
||||
<ClCompile Include="Exports.cpp" />
|
||||
<ClCompile Include="DataWriterLocal.cpp" />
|
||||
<ClCompile Include="dllmain.cpp">
|
||||
<CompileAsManaged Condition="$(DebugBuild)">false</CompileAsManaged>
|
||||
<PrecompiledHeader Condition="$(DebugBuild)">
|
||||
|
|
|
@ -1,13 +1,11 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
|
||||
<ItemGroup>
|
||||
<ClCompile Include="DataWriter.cpp" />
|
||||
<ClCompile Include="dllmain.cpp" />
|
||||
<ClCompile Include="HTKMLFReader.cpp" />
|
||||
<ClCompile Include="HTKMLFWriter.cpp" />
|
||||
<ClCompile Include="latticearchive.cpp" />
|
||||
<ClCompile Include="stdafx.cpp" />
|
||||
<ClCompile Include="DataReader.cpp" />
|
||||
<ClCompile Include="..\..\Common\TimerUtility.cpp">
|
||||
<Filter>Common</Filter>
|
||||
</ClCompile>
|
||||
|
@ -15,6 +13,8 @@
|
|||
<Filter>Common</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="..\..\Common\ExceptionWithCallStack.cpp" />
|
||||
<ClCompile Include="Exports.cpp" />
|
||||
<ClCompile Include="DataWriterLocal.cpp" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="biggrowablevectors.h" />
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
template <class ElemType>
|
||||
class HTKMLFWriter : public IDataWriter<ElemType>
|
||||
class HTKMLFWriter : public IDataWriter
|
||||
{
|
||||
private:
|
||||
std::vector<size_t> outputDims;
|
||||
|
@ -36,8 +36,6 @@ private:
|
|||
};
|
||||
|
||||
public:
|
||||
using LabelType = typename IDataWriter<ElemType>::LabelType;
|
||||
using LabelIdType = typename IDataWriter<ElemType>::LabelIdType;
|
||||
template <class ConfigRecordType>
|
||||
void InitFromConfig(const ConfigRecordType& writerConfig);
|
||||
virtual void Init(const ConfigParameters& config)
|
||||
|
|
|
@ -22,13 +22,14 @@ auto factory = [](const ConfigParameters& parameters) -> ReaderPtr
|
|||
return std::make_shared<ImageReader>(std::make_shared<HeapMemoryProvider>(), parameters);
|
||||
};
|
||||
|
||||
extern "C" DATAREADER_API void GetReaderF(IDataReader<float>** preader)
|
||||
extern "C" DATAREADER_API void GetReaderF(IDataReader** preader)
|
||||
{
|
||||
*preader = new ReaderShim<float>(factory);
|
||||
}
|
||||
|
||||
extern "C" DATAREADER_API void GetReaderD(IDataReader<double>** preader)
|
||||
extern "C" DATAREADER_API void GetReaderD(IDataReader** preader)
|
||||
{
|
||||
*preader = new ReaderShim<double>(factory);
|
||||
}
|
||||
} } }
|
||||
|
||||
}}}
|
||||
|
|
|
@ -4,6 +4,8 @@
|
|||
//
|
||||
// DataReader.cpp : Defines the exported functions for the DLL application.
|
||||
//
|
||||
// TODO: Rename to Exports.cpp
|
||||
//
|
||||
|
||||
#include "stdafx.h"
|
||||
#include "basetypes.h"
|
||||
|
@ -23,19 +25,13 @@
|
|||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
template <class ElemType>
|
||||
void DATAREADER_API GetReader(IDataReader<ElemType>** preader)
|
||||
extern "C" DATAREADER_API void GetReaderF(IDataReader** preader)
|
||||
{
|
||||
*preader = new HTKMLFReader<ElemType>();
|
||||
*preader = new HTKMLFReader<float>();
|
||||
}
|
||||
|
||||
extern "C" DATAREADER_API void GetReaderF(IDataReader<float>** preader)
|
||||
extern "C" DATAREADER_API void GetReaderD(IDataReader** preader)
|
||||
{
|
||||
GetReader(preader);
|
||||
}
|
||||
extern "C" DATAREADER_API void GetReaderD(IDataReader<double>** preader)
|
||||
{
|
||||
GetReader(preader);
|
||||
*preader = new HTKMLFReader<double>();
|
||||
}
|
||||
|
||||
// Utility function, in ConfigFile.cpp, but HTKMLFReader doesn't need that code...
|
||||
|
|
|
@ -4,6 +4,8 @@
|
|||
//
|
||||
// DataWriter.cpp : Defines the exported functions for the DLL application.
|
||||
//
|
||||
// TODO: This is similar but not identical to Common/DataWriter.cpp. Why is this DataWriter different? Can it be reconciled?
|
||||
//
|
||||
|
||||
#include "stdafx.h"
|
||||
#include "basetypes.h"
|
||||
|
@ -16,59 +18,54 @@
|
|||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
template <class ElemType>
|
||||
void DATAWRITER_API GetWriter(IDataWriter<ElemType>** pwriter)
|
||||
// TODO: move these to Exports.cpp
|
||||
extern "C" DATAWRITER_API void GetWriterF(IDataWriter** pwriter)
|
||||
{
|
||||
*pwriter = new HTKMLFWriter<ElemType>();
|
||||
*pwriter = new HTKMLFWriter<float>();
|
||||
}
|
||||
extern "C" DATAWRITER_API void GetWriterD(IDataWriter** pwriter)
|
||||
{
|
||||
*pwriter = new HTKMLFWriter<double>();
|
||||
}
|
||||
|
||||
extern "C" DATAWRITER_API void GetWriterF(IDataWriter<float>** pwriter)
|
||||
{
|
||||
GetWriter(pwriter);
|
||||
}
|
||||
extern "C" DATAWRITER_API void GetWriterD(IDataWriter<double>** pwriter)
|
||||
{
|
||||
GetWriter(pwriter);
|
||||
}
|
||||
|
||||
template <class ElemType>
|
||||
template <class ConfigRecordType>
|
||||
void DataWriter<ElemType>::InitFromConfig(const ConfigRecordType& writerConfig)
|
||||
void DataWriter::InitFromConfig(const ConfigRecordType& writerConfig)
|
||||
{
|
||||
m_dataWriter = new HTKMLFWriter<ElemType>();
|
||||
wstring precision = writerConfig(L"precision", L"float");
|
||||
if (precision == L"float")
|
||||
m_dataWriter = new HTKMLFWriter<float>();
|
||||
else if (precision == L"double")
|
||||
m_dataWriter = new HTKMLFWriter<double>();
|
||||
else
|
||||
InvalidArgument("DataWriter (Kaldi HTKMLFWriter): The 'precision' parameter must be 'float' or 'double'.");
|
||||
|
||||
m_dataWriter->Init(writerConfig);
|
||||
}
|
||||
|
||||
// Destroy - cleanup and remove this class
|
||||
// NOTE: this destroys the object, and it can't be used past this point
|
||||
template <class ElemType>
|
||||
void DataWriter<ElemType>::Destroy()
|
||||
void DataWriter::Destroy()
|
||||
{
|
||||
delete m_dataWriter;
|
||||
m_dataWriter = NULL;
|
||||
}
|
||||
|
||||
// DataWriter Constructor
|
||||
// config - [in] configuration data for the data writer
|
||||
template <class ElemType>
|
||||
template <class ConfigRecordType>
|
||||
DataWriter<ElemType>::DataWriter(const ConfigRecordType& config)
|
||||
DataWriter::DataWriter(const ConfigRecordType& config)
|
||||
{
|
||||
Init(config);
|
||||
}
|
||||
|
||||
// destructor - cleanup temp files, etc.
|
||||
template <class ElemType>
|
||||
DataWriter<ElemType>::~DataWriter()
|
||||
DataWriter::~DataWriter()
|
||||
{
|
||||
delete m_dataWriter;
|
||||
m_dataWriter = NULL;
|
||||
Destroy();
|
||||
}
|
||||
|
||||
// GetSections - Get the sections of the file
|
||||
// sections - a map of section name to section. Data sepcifications from config file will be used to determine where and how to save data
|
||||
template <class ElemType>
|
||||
void DataWriter<ElemType>::GetSections(std::map<std::wstring, SectionType, nocase_compare>& sections)
|
||||
void DataWriter::GetSections(std::map<std::wstring, SectionType, nocase_compare>& sections)
|
||||
{
|
||||
m_dataWriter->GetSections(sections);
|
||||
}
|
||||
|
@ -79,8 +76,7 @@ void DataWriter<ElemType>::GetSections(std::map<std::wstring, SectionType, nocas
|
|||
// numRecords - number of records we are saving, can be zero if not applicable
|
||||
// datasetSize - Size of the dataset
|
||||
// byteVariableSized - for variable sized data, size of current block to be written, zero when not used, or ignored if not variable sized data
|
||||
template <class ElemType>
|
||||
bool DataWriter<ElemType>::SaveData(size_t recordStart, const std::map<std::wstring, void*, nocase_compare>& matrices, size_t numRecords, size_t datasetSize, size_t byteVariableSized)
|
||||
bool DataWriter::SaveData(size_t recordStart, const std::map<std::wstring, void*, nocase_compare>& matrices, size_t numRecords, size_t datasetSize, size_t byteVariableSized)
|
||||
{
|
||||
return m_dataWriter->SaveData(recordStart, matrices, numRecords, datasetSize, byteVariableSized);
|
||||
}
|
||||
|
@ -88,13 +84,9 @@ bool DataWriter<ElemType>::SaveData(size_t recordStart, const std::map<std::wstr
|
|||
// SaveMapping - save a map into the file
|
||||
// saveId - name of the section to save into (section:subsection format)
|
||||
// labelMapping - map we are saving to the file
|
||||
template <class ElemType>
|
||||
void DataWriter<ElemType>::SaveMapping(std::wstring saveId, const std::map<LabelIdType, LabelType>& labelMapping)
|
||||
void DataWriter::SaveMapping(std::wstring saveId, const std::map<LabelIdType, LabelType>& labelMapping)
|
||||
{
|
||||
m_dataWriter->SaveMapping(saveId, labelMapping);
|
||||
}
|
||||
|
||||
//The explicit instantiation
|
||||
template class DataWriter<double>;
|
||||
template class DataWriter<float>;
|
||||
} } }
|
||||
}}}
|
||||
|
|
|
@ -1916,7 +1916,7 @@ bool HTKMLFReader<ElemType>::SetNetOutput(
|
|||
// GetLabelMapping - Gets the label mapping from integer to type in file
|
||||
// mappingTable - a map from numeric datatype to native label type stored as a string
|
||||
template <class ElemType>
|
||||
const std::map<typename IDataReader<ElemType>::LabelIdType, typename IDataReader<ElemType>::LabelType>& HTKMLFReader<ElemType>::GetLabelMapping(const std::wstring& /*sectionName*/)
|
||||
const std::map<IDataReader::LabelIdType, IDataReader::LabelType>& HTKMLFReader<ElemType>::GetLabelMapping(const std::wstring& /*sectionName*/)
|
||||
{
|
||||
return m_idToLabelMap;
|
||||
}
|
||||
|
@ -1925,7 +1925,7 @@ const std::map<typename IDataReader<ElemType>::LabelIdType, typename IDataReader
|
|||
// labelMapping - mapping table from label values to IDs (must be 0-n)
|
||||
// note: for tasks with labels, the mapping table must be the same between a training run and a testing run
|
||||
template <class ElemType>
|
||||
void HTKMLFReader<ElemType>::SetLabelMapping(const std::wstring& /*sectionName*/, const std::map<typename IDataReader<ElemType>::LabelIdType, LabelType>& labelMapping)
|
||||
void HTKMLFReader<ElemType>::SetLabelMapping(const std::wstring& /*sectionName*/, const std::map<LabelIdType, LabelType>& labelMapping)
|
||||
{
|
||||
m_idToLabelMap = labelMapping;
|
||||
}
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
template <class ElemType>
|
||||
class HTKMLFReader : public IDataReader<ElemType>
|
||||
class HTKMLFReader : public IDataReader
|
||||
{
|
||||
private:
|
||||
msra::dbn::minibatchiterator* m_mbiter;
|
||||
|
@ -71,8 +71,6 @@ private:
|
|||
bool m_noData;
|
||||
|
||||
bool m_trainOrTest; // if false, in file writing mode
|
||||
using LabelType = typename IDataReader<ElemType>::LabelType;
|
||||
using LabelIdType = typename IDataReader<ElemType>::LabelIdType;
|
||||
|
||||
std::map<LabelIdType, LabelType> m_idToLabelMap;
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
template <class ElemType>
|
||||
class HTKMLFWriter : public IDataWriter<ElemType>
|
||||
class HTKMLFWriter : public IDataWriter
|
||||
{
|
||||
private:
|
||||
std::vector<size_t> outputDims;
|
||||
|
|
|
@ -7,23 +7,28 @@
|
|||
|
||||
#include "stdafx.h"
|
||||
#define DATAREADER_EXPORTS
|
||||
#include "DataReader.h"
|
||||
#define DATAWRITER_EXPORTS
|
||||
#include "SequenceReader.h"
|
||||
#include "SequenceWriter.h"
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
template <class ElemType>
|
||||
void DATAREADER_API GetReader(IDataReader<ElemType>** preader)
|
||||
extern "C" DATAREADER_API void GetReaderF(IDataReader** preader)
|
||||
{
|
||||
*preader = new BatchSequenceReader<ElemType>();
|
||||
*preader = new BatchSequenceReader<float>();
|
||||
}
|
||||
extern "C" DATAREADER_API void GetReaderD(IDataReader** preader)
|
||||
{
|
||||
*preader = new BatchSequenceReader<double>();
|
||||
}
|
||||
|
||||
extern "C" DATAREADER_API void GetReaderF(IDataReader<float>** preader)
|
||||
extern "C" DATAWRITER_API void GetWriterF(IDataWriter** pwriter)
|
||||
{
|
||||
GetReader(preader);
|
||||
*pwriter = new LMSequenceWriter<float>();
|
||||
}
|
||||
extern "C" DATAREADER_API void GetReaderD(IDataReader<double>** preader)
|
||||
extern "C" DATAWRITER_API void GetWriterD(IDataWriter** pwriter)
|
||||
{
|
||||
GetReader(preader);
|
||||
*pwriter = new LMSequenceWriter<double>();
|
||||
}
|
||||
} } }
|
||||
|
||||
}}}
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
#ifdef LEAKDETECT
|
||||
#include <vld.h> // leak detection
|
||||
#endif
|
||||
#include "DataWriter.h"
|
||||
#include "fileutil.h" // for fexists()
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
@ -68,7 +69,7 @@ size_t SequenceReader<ElemType>::RecordsToRead(size_t mbStartSample, bool tail)
|
|||
// endOfDataCheck - check if we are at the end of the dataset (no wraparound)
|
||||
// returns - true if we have more to read, false if we hit the end of the dataset
|
||||
template <class ElemType>
|
||||
typename IDataReader<ElemType>::LabelIdType SequenceReader<ElemType>::GetIdFromLabel(const std::string& labelValue, LabelInfo& labelInfo)
|
||||
IDataReader::LabelIdType SequenceReader<ElemType>::GetIdFromLabel(const std::string& labelValue, LabelInfo& labelInfo)
|
||||
{
|
||||
auto found = labelInfo.mapLabelToId.find(labelValue);
|
||||
// not yet found, add to the map
|
||||
|
@ -701,11 +702,11 @@ void SequenceReader<ElemType>::InitCache(const ConfigParameters& readerConfig)
|
|||
// mmodify the config so the reader types look correct
|
||||
config["readerType"] = config("writerType");
|
||||
config["file"] = filesList;
|
||||
m_cachingReader = new DataReader<ElemType>(config);
|
||||
m_cachingReader = new DataReader(config);
|
||||
}
|
||||
else
|
||||
{
|
||||
m_cachingWriter = new DataWriter<ElemType>(readerConfig);
|
||||
m_cachingWriter = new DataWriter(readerConfig);
|
||||
|
||||
// now get the section names for map and category types
|
||||
std::map<std::wstring, SectionType, nocase_compare> sections;
|
||||
|
@ -877,8 +878,8 @@ void SequenceReader<ElemType>::StartMinibatchLoop(size_t mbSize, size_t epoch, s
|
|||
{
|
||||
m_labelsBuffer = new ElemType[mbSize * labelInfo.dim];
|
||||
memset(m_labelsBuffer, 0, sizeof(ElemType) * mbSize * labelInfo.dim);
|
||||
m_labelsIdBuffer = new typename IDataReader<ElemType>::LabelIdType[mbSize];
|
||||
memset(m_labelsIdBuffer, 0, sizeof(typename IDataReader<ElemType>::LabelIdType) * mbSize);
|
||||
m_labelsIdBuffer = new LabelIdType[mbSize];
|
||||
memset(m_labelsIdBuffer, 0, sizeof(LabelIdType) * mbSize);
|
||||
}
|
||||
else if (labelInfo.type != labelNone)
|
||||
{
|
||||
|
@ -1188,7 +1189,7 @@ bool SequenceReader<ElemType>::GetMinibatch(StreamMinibatchInputs& matrices)
|
|||
if (labelInfo.type == labelCategory)
|
||||
{
|
||||
memset(m_labelsBuffer, 0, sizeof(ElemType) * labelInfo.dim * actualmbsize);
|
||||
memset(m_labelsIdBuffer, 0, sizeof(typename IDataReader<ElemType>::LabelIdType) * actualmbsize);
|
||||
memset(m_labelsIdBuffer, 0, sizeof(LabelIdType) * actualmbsize);
|
||||
}
|
||||
else if (labelInfo.type != labelNone)
|
||||
{
|
||||
|
@ -1307,7 +1308,7 @@ bool SequenceReader<ElemType>::GetMinibatch(StreamMinibatchInputs& matrices)
|
|||
// GetLabelMapping - Gets the label mapping from integer index to label type
|
||||
// returns - a map from numeric datatype to native label type
|
||||
template <class ElemType>
|
||||
const std::map<typename IDataReader<ElemType>::LabelIdType, typename IDataReader<ElemType>::LabelType>& SequenceReader<ElemType>::GetLabelMapping(const std::wstring& sectionName)
|
||||
const std::map<IDataReader::LabelIdType, IDataReader::LabelType>& SequenceReader<ElemType>::GetLabelMapping(const std::wstring& sectionName)
|
||||
{
|
||||
FailBecauseDeprecated(__FUNCTION__); // DEPRECATED CLASS, SHOULD NOT BE USED ANYMORE
|
||||
|
||||
|
@ -1324,7 +1325,7 @@ const std::map<typename IDataReader<ElemType>::LabelIdType, typename IDataReader
|
|||
// labelMapping - mapping table from label values to IDs (must be 0-n)
|
||||
// note: for tasks with labels, the mapping table must be the same between a training run and a testing run
|
||||
template <class ElemType>
|
||||
void SequenceReader<ElemType>::SetLabelMapping(const std::wstring& /*sectionName*/, const std::map<typename IDataReader<ElemType>::LabelIdType, LabelType>& labelMapping)
|
||||
void SequenceReader<ElemType>::SetLabelMapping(const std::wstring& /*sectionName*/, const std::map<IDataReader::LabelIdType, LabelType>& labelMapping)
|
||||
{
|
||||
FailBecauseDeprecated(__FUNCTION__); // DEPRECATED CLASS, SHOULD NOT BE USED ANYMORE
|
||||
|
||||
|
|
|
@ -109,9 +109,8 @@ public:
|
|||
|
||||
// Note: This class is deprecated for standalone use, only used as a base for BatchSequenceReader which overrides most of the functions.
|
||||
template <class ElemType>
|
||||
class SequenceReader : public IDataReader<ElemType>
|
||||
class SequenceReader : public IDataReader
|
||||
{
|
||||
typedef IDataReader<ElemType> Base;
|
||||
protected:
|
||||
bool m_idx2clsRead;
|
||||
bool m_clsinfoRead;
|
||||
|
@ -119,9 +118,6 @@ protected:
|
|||
bool m_idx2probRead;
|
||||
|
||||
public:
|
||||
using LabelType = typename Base::LabelType;
|
||||
using LabelIdType = typename Base::LabelIdType;
|
||||
|
||||
map<string, int> word4idx;
|
||||
map<int, string> idx4word;
|
||||
map<int, int> idx4class;
|
||||
|
@ -211,8 +207,8 @@ protected:
|
|||
} m_labelInfo[labelInfoNum];
|
||||
|
||||
// caching support
|
||||
DataReader<ElemType>* m_cachingReader;
|
||||
DataWriter<ElemType>* m_cachingWriter;
|
||||
DataReader* m_cachingReader;
|
||||
DataWriter* m_cachingWriter;
|
||||
ConfigParameters m_readerConfig;
|
||||
void InitCache(const ConfigParameters& config);
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
template <class ElemType>
|
||||
class LMSequenceWriter : public IDataWriter<ElemType>
|
||||
class LMSequenceWriter : public IDataWriter
|
||||
{
|
||||
private:
|
||||
std::vector<size_t> outputDims;
|
||||
|
@ -49,8 +49,6 @@ public:
|
|||
}
|
||||
|
||||
public:
|
||||
using LabelType = typename IDataWriter<ElemType>::LabelType;
|
||||
using LabelIdType = typename IDataWriter<ElemType>::LabelIdType;
|
||||
void GetSections(std::map<std::wstring, SectionType, nocase_compare>& /*sections*/)
|
||||
{
|
||||
}
|
||||
|
@ -77,20 +75,4 @@ public:
|
|||
};
|
||||
};
|
||||
|
||||
template <class ElemType>
|
||||
void DATAWRITER_API GetWriter(IDataWriter<ElemType>** pwriter)
|
||||
{
|
||||
assert(pwriter != nullptr);
|
||||
*pwriter = new LMSequenceWriter<ElemType>();
|
||||
assert(*pwriter != nullptr);
|
||||
}
|
||||
|
||||
extern "C" DATAWRITER_API void GetWriterF(IDataWriter<float>** pwriter)
|
||||
{
|
||||
GetWriter(pwriter);
|
||||
}
|
||||
extern "C" DATAWRITER_API void GetWriterD(IDataWriter<double>** pwriter)
|
||||
{
|
||||
GetWriter(pwriter);
|
||||
}
|
||||
} } }
|
||||
}}}
|
||||
|
|
|
@ -4,6 +4,8 @@
|
|||
//
|
||||
// DataWriter.cpp : Defines the exported functions for the DLL application.
|
||||
//
|
||||
// TODO: Unify with shared DataWriter.cpp.
|
||||
//
|
||||
|
||||
#include "stdafx.h"
|
||||
#include "Basics.h"
|
||||
|
@ -14,59 +16,45 @@
|
|||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
template <class ElemType>
|
||||
void DATAWRITER_API GetWriter(IDataWriter<ElemType>** pwriter)
|
||||
{
|
||||
*pwriter = new LUSequenceWriter<ElemType>();
|
||||
}
|
||||
|
||||
extern "C" DATAWRITER_API void GetWriterF(IDataWriter<float>** pwriter)
|
||||
{
|
||||
GetWriter(pwriter);
|
||||
}
|
||||
extern "C" DATAWRITER_API void GetWriterD(IDataWriter<double>** pwriter)
|
||||
{
|
||||
GetWriter(pwriter);
|
||||
}
|
||||
|
||||
template <class ElemType>
|
||||
template <class ConfigRecordType>
|
||||
void DataWriter<ElemType>::InitFromConfig(const ConfigRecordType& writerConfig)
|
||||
void DataWriter::InitFromConfig(const ConfigRecordType& writerConfig)
|
||||
{
|
||||
m_dataWriter = new LUSequenceWriter<ElemType>();
|
||||
wstring precision = writerConfig(L"precision", L"float");
|
||||
if (precision == L"float")
|
||||
m_dataWriter = new LUSequenceWriter<float>();
|
||||
else if (precision == L"double")
|
||||
m_dataWriter = new LUSequenceWriter<double>();
|
||||
else
|
||||
InvalidArgument("DataWriter (LUSequenceWriter): The 'precision' parameter must be 'float' or 'double'.");
|
||||
|
||||
m_dataWriter->Init(writerConfig);
|
||||
}
|
||||
|
||||
// Destroy - cleanup and remove this class
|
||||
// NOTE: this destroys the object, and it can't be used past this point
|
||||
template <class ElemType>
|
||||
void DataWriter<ElemType>::Destroy()
|
||||
void DataWriter::Destroy()
|
||||
{
|
||||
delete m_dataWriter;
|
||||
m_dataWriter = NULL;
|
||||
// TODO: don't we need to destroy ourselves?
|
||||
}
|
||||
|
||||
// DataWriter Constructor
|
||||
// config - [in] configuration data for the data writer
|
||||
template <class ElemType>
|
||||
template <class ConfigRecordType>
|
||||
DataWriter<ElemType>::DataWriter(const ConfigRecordType& config)
|
||||
DataWriter::DataWriter(const ConfigRecordType& config)
|
||||
{
|
||||
Init(config);
|
||||
}
|
||||
|
||||
// destructor - cleanup temp files, etc.
|
||||
template <class ElemType>
|
||||
DataWriter<ElemType>::~DataWriter()
|
||||
DataWriter::~DataWriter()
|
||||
{
|
||||
delete m_dataWriter;
|
||||
m_dataWriter = NULL;
|
||||
Destroy();
|
||||
}
|
||||
|
||||
// GetSections - Get the sections of the file
|
||||
// sections - a map of section name to section. Data sepcifications from config file will be used to determine where and how to save data
|
||||
template <class ElemType>
|
||||
void DataWriter<ElemType>::GetSections(std::map<std::wstring, SectionType, nocase_compare>& sections)
|
||||
void DataWriter::GetSections(std::map<std::wstring, SectionType, nocase_compare>& sections)
|
||||
{
|
||||
m_dataWriter->GetSections(sections);
|
||||
}
|
||||
|
@ -77,8 +65,7 @@ void DataWriter<ElemType>::GetSections(std::map<std::wstring, SectionType, nocas
|
|||
// numRecords - number of records we are saving, can be zero if not applicable
|
||||
// datasetSize - Size of the dataset
|
||||
// byteVariableSized - for variable sized data, size of current block to be written, zero when not used, or ignored if not variable sized data
|
||||
template <class ElemType>
|
||||
bool DataWriter<ElemType>::SaveData(size_t recordStart, const std::map<std::wstring, void*, nocase_compare>& matrices, size_t numRecords, size_t datasetSize, size_t byteVariableSized)
|
||||
bool DataWriter::SaveData(size_t recordStart, const std::map<std::wstring, void*, nocase_compare>& matrices, size_t numRecords, size_t datasetSize, size_t byteVariableSized)
|
||||
{
|
||||
return m_dataWriter->SaveData(recordStart, matrices, numRecords, datasetSize, byteVariableSized);
|
||||
}
|
||||
|
@ -86,13 +73,9 @@ bool DataWriter<ElemType>::SaveData(size_t recordStart, const std::map<std::wstr
|
|||
// SaveMapping - save a map into the file
|
||||
// saveId - name of the section to save into (section:subsection format)
|
||||
// labelMapping - map we are saving to the file
|
||||
template <class ElemType>
|
||||
void DataWriter<ElemType>::SaveMapping(std::wstring saveId, const std::map<typename LabelIdType, typename LabelType>& labelMapping)
|
||||
void DataWriter::SaveMapping(std::wstring saveId, const std::map<LabelIdType, LabelType>& labelMapping)
|
||||
{
|
||||
m_dataWriter->SaveMapping(saveId, labelMapping);
|
||||
}
|
||||
|
||||
//The explicit instantiation
|
||||
template class DataWriter<double>;
|
||||
template class DataWriter<float>;
|
||||
} } }
|
||||
}}}
|
|
@ -7,25 +7,28 @@
|
|||
|
||||
#include "stdafx.h"
|
||||
#define DATAREADER_EXPORTS
|
||||
#include "DataReader.h"
|
||||
#define DATAWRITER_EXPORTS
|
||||
#include "LUSequenceReader.h"
|
||||
#include "LUSequenceWriter.h"
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
template <class ElemType>
|
||||
void DATAREADER_API GetReader(IDataReader<ElemType>** preader)
|
||||
extern "C" DATAREADER_API void GetReaderF(IDataReader** preader)
|
||||
{
|
||||
// *preader = new LUSequenceReader<ElemType>();
|
||||
// *preader = new BatchLUSequenceReader<ElemType>();
|
||||
*preader = new MultiIOBatchLUSequenceReader<ElemType>();
|
||||
*preader = new MultiIOBatchLUSequenceReader<float>();
|
||||
}
|
||||
extern "C" DATAREADER_API void GetReaderD(IDataReader** preader)
|
||||
{
|
||||
*preader = new MultiIOBatchLUSequenceReader<double>();
|
||||
}
|
||||
|
||||
extern "C" DATAREADER_API void GetReaderF(IDataReader<float>** preader)
|
||||
extern "C" DATAWRITER_API void GetWriterF(IDataWriter** pwriter)
|
||||
{
|
||||
GetReader(preader);
|
||||
*pwriter = new LUSequenceWriter<float>();
|
||||
}
|
||||
extern "C" DATAREADER_API void GetReaderD(IDataReader<double>** preader)
|
||||
extern "C" DATAWRITER_API void GetWriterD(IDataWriter** pwriter)
|
||||
{
|
||||
GetReader(preader);
|
||||
*pwriter = new LUSequenceWriter<double>();
|
||||
}
|
||||
} } }
|
||||
|
||||
}}}
|
||||
|
|
|
@ -5,6 +5,8 @@
|
|||
// LUSequenceParser.h : Parses the UCI format using a custom state machine (for speed)
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <assert.h>
|
||||
|
|
|
@ -47,7 +47,7 @@ enum ReaderMode
|
|||
};
|
||||
|
||||
template <class ElemType>
|
||||
class LUSequenceReader : public IDataReader<ElemType>
|
||||
class LUSequenceReader : public IDataReader
|
||||
{
|
||||
protected:
|
||||
bool m_idx2clsRead;
|
||||
|
@ -149,8 +149,8 @@ protected:
|
|||
} m_labelInfo[labelInfoNum];
|
||||
|
||||
// caching support
|
||||
DataReader<ElemType>* m_cachingReader;
|
||||
DataWriter<ElemType>* m_cachingWriter;
|
||||
DataReader* m_cachingReader;
|
||||
DataWriter* m_cachingWriter;
|
||||
ConfigParameters m_readerConfig;
|
||||
void InitCache(const ConfigParameters& config);
|
||||
|
||||
|
|
|
@ -120,7 +120,7 @@
|
|||
<ClCompile Include="..\..\Common\Config.cpp">
|
||||
<PrecompiledHeader>NotUsing</PrecompiledHeader>
|
||||
</ClCompile>
|
||||
<ClCompile Include="DataWriter.cpp" />
|
||||
<ClCompile Include="DataWriterLocal.cpp" />
|
||||
<ClCompile Include="Exports.cpp" />
|
||||
<ClCompile Include="dllmain.cpp">
|
||||
<CompileAsManaged Condition="$(DebugBuild)">false</CompileAsManaged>
|
||||
|
|
|
@ -10,9 +10,6 @@
|
|||
<ClCompile Include="..\..\Common\DataReader.cpp">
|
||||
<Filter>Common</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="DataWriter.cpp">
|
||||
<Filter>Common</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="..\..\Common\fileutil.cpp">
|
||||
<Filter>Common</Filter>
|
||||
</ClCompile>
|
||||
|
@ -21,6 +18,9 @@
|
|||
</ClCompile>
|
||||
<ClCompile Include="..\..\Common\Config.cpp" />
|
||||
<ClCompile Include="..\..\Common\ExceptionWithCallStack.cpp" />
|
||||
<ClCompile Include="DataWriterLocal.cpp">
|
||||
<Filter>Common</Filter>
|
||||
</ClCompile>
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="LUSequenceWriter.h" />
|
||||
|
@ -51,8 +51,5 @@
|
|||
<Filter Include="Common\Include">
|
||||
<UniqueIdentifier>{85d2fa50-2b95-4ec7-8f2c-c5c0b1cb493e}</UniqueIdentifier>
|
||||
</Filter>
|
||||
<Filter Include="Duplicates to remove">
|
||||
<UniqueIdentifier>{10f819fb-8861-4607-9389-60ca80f968c2}</UniqueIdentifier>
|
||||
</Filter>
|
||||
</ItemGroup>
|
||||
</Project>
|
|
@ -2,17 +2,21 @@
|
|||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "DataWriter.h"
|
||||
#include "LUSequenceParser.h"
|
||||
#include <stdio.h>
|
||||
|
||||
#ifndef MAX_STRING
|
||||
#define MAX_STRING 2048
|
||||
#endif
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
template <class ElemType>
|
||||
class LUSequenceWriter : public IDataWriter<ElemType>
|
||||
class LUSequenceWriter : public IDataWriter
|
||||
{
|
||||
private:
|
||||
std::vector<size_t> outputDims;
|
||||
|
@ -42,7 +46,7 @@ public:
|
|||
void GetSections(std::map<std::wstring, SectionType, nocase_compare>& /*sections*/)
|
||||
{
|
||||
}
|
||||
void SaveMapping(std::wstring saveId, const std::map<typename LabelIdType, typename LabelType>& /*labelMapping*/)
|
||||
void SaveMapping(std::wstring saveId, const std::map<LabelIdType, LabelType>& /*labelMapping*/)
|
||||
{
|
||||
}
|
||||
|
||||
|
|
|
@ -12,18 +12,13 @@
|
|||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
template <class ElemType>
|
||||
void DATAREADER_API GetReader(IDataReader<ElemType>** preader)
|
||||
extern "C" DATAREADER_API void GetReaderF(IDataReader** preader)
|
||||
{
|
||||
*preader = new LibSVMBinaryReader<ElemType>();
|
||||
*preader = new LibSVMBinaryReader<float>();
|
||||
}
|
||||
extern "C" DATAREADER_API void GetReaderD(IDataReader** preader)
|
||||
{
|
||||
*preader = new LibSVMBinaryReader<double>();
|
||||
}
|
||||
|
||||
extern "C" DATAREADER_API void GetReaderF(IDataReader<float>** preader)
|
||||
{
|
||||
GetReader(preader);
|
||||
}
|
||||
extern "C" DATAREADER_API void GetReaderD(IDataReader<double>** preader)
|
||||
{
|
||||
GetReader(preader);
|
||||
}
|
||||
} } }
|
||||
}}}
|
||||
|
|
|
@ -226,12 +226,9 @@ private:
|
|||
};
|
||||
|
||||
template <class ElemType>
|
||||
class LibSVMBinaryReader : public IDataReader<ElemType>
|
||||
class LibSVMBinaryReader : public IDataReader
|
||||
{
|
||||
public:
|
||||
using LabelType = typename IDataReader<ElemType>::LabelType;
|
||||
using LabelIdType = typename IDataReader<ElemType>::LabelIdType;
|
||||
|
||||
virtual void Init(const ConfigParameters& config) override
|
||||
{
|
||||
InitFromConfig(config);
|
||||
|
|
|
@ -19,7 +19,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
typedef ReaderPtr (*ReaderFactory)(const ConfigParameters& parameters);
|
||||
|
||||
template <class ElemType>
|
||||
class ReaderShim : public IDataReader<ElemType>
|
||||
class ReaderShim : public IDataReader
|
||||
{
|
||||
public:
|
||||
explicit ReaderShim(ReaderFactory factory);
|
||||
|
|
|
@ -12,18 +12,13 @@
|
|||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
template <class ElemType>
|
||||
void DATAREADER_API GetReader(IDataReader<ElemType>** preader)
|
||||
extern "C" DATAREADER_API void GetReaderF(IDataReader** preader)
|
||||
{
|
||||
*preader = new SparsePCReader<ElemType>();
|
||||
*preader = new SparsePCReader<float>();
|
||||
}
|
||||
extern "C" DATAREADER_API void GetReaderD(IDataReader** preader)
|
||||
{
|
||||
*preader = new SparsePCReader<double>();
|
||||
}
|
||||
|
||||
extern "C" DATAREADER_API void GetReaderF(IDataReader<float>** preader)
|
||||
{
|
||||
GetReader(preader);
|
||||
}
|
||||
extern "C" DATAREADER_API void GetReaderD(IDataReader<double>** preader)
|
||||
{
|
||||
GetReader(preader);
|
||||
}
|
||||
} } }
|
||||
}}}
|
||||
|
|
|
@ -317,7 +317,7 @@ bool SparsePCReader<ElemType>::DataEnd() { return true; }
|
|||
// GetLabelMapping - Gets the label mapping from integer index to label type
|
||||
// returns - a map from numeric datatype to native label type
|
||||
template <class ElemType>
|
||||
const std::map<typename IDataReader<ElemType>::LabelIdType, typename IDataReader<ElemType>::LabelType>& SparsePCReader<ElemType>::GetLabelMapping(const std::wstring& /*sectionName*/)
|
||||
const std::map<IDataReader::LabelIdType, IDataReader::LabelType>& SparsePCReader<ElemType>::GetLabelMapping(const std::wstring& /*sectionName*/)
|
||||
{
|
||||
return m_mapIdToLabel;
|
||||
}
|
||||
|
@ -326,7 +326,7 @@ const std::map<typename IDataReader<ElemType>::LabelIdType, typename IDataReader
|
|||
// labelMapping - mapping table from label values to IDs (must be 0-n)
|
||||
// note: for tasks with labels, the mapping table must be the same between a training run and a testing run
|
||||
template <class ElemType>
|
||||
void SparsePCReader<ElemType>::SetLabelMapping(const std::wstring& /*sectionName*/, const std::map<typename IDataReader<ElemType>::LabelIdType, LabelType>& labelMapping)
|
||||
void SparsePCReader<ElemType>::SetLabelMapping(const std::wstring& /*sectionName*/, const std::map<IDataReader::LabelIdType, LabelType>& labelMapping)
|
||||
{
|
||||
m_mapIdToLabel = labelMapping;
|
||||
m_mapLabelToId.clear();
|
||||
|
|
|
@ -21,12 +21,8 @@
|
|||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
template <class ElemType>
|
||||
class SparsePCReader : public IDataReader<ElemType>
|
||||
class SparsePCReader : public IDataReader
|
||||
{
|
||||
public:
|
||||
using LabelType = typename IDataReader<ElemType>::LabelType;
|
||||
using LabelIdType = typename IDataReader<ElemType>::LabelIdType;
|
||||
private:
|
||||
ConfigParameters m_readerConfig;
|
||||
std::wstring m_file;
|
||||
size_t m_featureCount;
|
||||
|
|
|
@ -12,18 +12,13 @@
|
|||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
template <class ElemType>
|
||||
void DATAREADER_API GetReader(IDataReader<ElemType>** preader)
|
||||
extern "C" DATAREADER_API void GetReaderF(IDataReader** preader)
|
||||
{
|
||||
*preader = new UCIFastReader<ElemType>();
|
||||
*preader = new UCIFastReader<float>();
|
||||
}
|
||||
extern "C" DATAREADER_API void GetReaderD(IDataReader** preader)
|
||||
{
|
||||
*preader = new UCIFastReader<double>();
|
||||
}
|
||||
|
||||
extern "C" DATAREADER_API void GetReaderF(IDataReader<float>** preader)
|
||||
{
|
||||
GetReader(preader);
|
||||
}
|
||||
extern "C" DATAREADER_API void GetReaderD(IDataReader<double>** preader)
|
||||
{
|
||||
GetReader(preader);
|
||||
}
|
||||
} } }
|
||||
}}}
|
||||
|
|
|
@ -500,11 +500,11 @@ void UCIFastReader<ElemType>::InitCache(const ConfigParameters& readerConfig)
|
|||
// mmodify the config so the reader types look correct
|
||||
config["readerType"] = config("writerType");
|
||||
config["file"] = filesList;
|
||||
m_cachingReader = new DataReader<ElemType>(config);
|
||||
m_cachingReader = new DataReader(config);
|
||||
}
|
||||
else
|
||||
{
|
||||
m_cachingWriter = new DataWriter<ElemType>(readerConfig);
|
||||
m_cachingWriter = new DataWriter(readerConfig);
|
||||
|
||||
// now get the section names for map and category types
|
||||
std::map<std::wstring, SectionType, nocase_compare> sections;
|
||||
|
@ -1015,7 +1015,7 @@ bool UCIFastReader<ElemType>::GetMinibatchImpl(StreamMinibatchInputs& matrices)
|
|||
// GetLabelMapping - Gets the label mapping from integer index to label type
|
||||
// returns - a map from numeric datatype to native label type
|
||||
template <class ElemType>
|
||||
const std::map<typename IDataReader<ElemType>::LabelIdType, typename IDataReader<ElemType>::LabelType>& UCIFastReader<ElemType>::GetLabelMapping(const std::wstring& sectionName)
|
||||
const std::map<IDataReader::LabelIdType, IDataReader::LabelType>& UCIFastReader<ElemType>::GetLabelMapping(const std::wstring& sectionName)
|
||||
{
|
||||
if (m_cachingReader)
|
||||
{
|
||||
|
@ -1028,7 +1028,7 @@ const std::map<typename IDataReader<ElemType>::LabelIdType, typename IDataReader
|
|||
// labelMapping - mapping table from label values to IDs (must be 0-n)
|
||||
// note: for tasks with labels, the mapping table must be the same between a training run and a testing run
|
||||
template <class ElemType>
|
||||
void UCIFastReader<ElemType>::SetLabelMapping(const std::wstring& /*sectionName*/, const std::map<typename IDataReader<ElemType>::LabelIdType, LabelType>& labelMapping)
|
||||
void UCIFastReader<ElemType>::SetLabelMapping(const std::wstring& /*sectionName*/, const std::map<LabelIdType, LabelType>& labelMapping)
|
||||
{
|
||||
if (m_cachingReader)
|
||||
{
|
||||
|
|
|
@ -36,15 +36,8 @@ enum LabelKind
|
|||
};
|
||||
|
||||
template <class ElemType>
|
||||
class UCIFastReader : public IDataReader<ElemType>
|
||||
class UCIFastReader : public IDataReader
|
||||
{
|
||||
public:
|
||||
using LabelType = typename IDataReader<ElemType>::LabelType;
|
||||
using LabelIdType = typename IDataReader<ElemType>::LabelIdType;
|
||||
using IDataReader<ElemType>::mRequestedNumParallelSequences;
|
||||
// typedef std::string LabelType;
|
||||
// typedef unsigned LabelIdType;
|
||||
private:
|
||||
shared_ptr<UCIParser<ElemType, LabelType>> m_parser;
|
||||
size_t m_mbSize; // size of minibatch requested
|
||||
LabelIdType m_labelIdMax; // maximum label ID we have encountered so far
|
||||
|
@ -102,8 +95,8 @@ private:
|
|||
unique_ptr<CUDAPageLockedMemAllocator> m_cudaAllocator;
|
||||
|
||||
// caching support
|
||||
DataReader<ElemType>* m_cachingReader;
|
||||
DataWriter<ElemType>* m_cachingWriter;
|
||||
DataReader* m_cachingReader;
|
||||
DataWriter* m_cachingWriter;
|
||||
ConfigParameters m_readerConfig;
|
||||
void InitCache(const ConfigParameters& config);
|
||||
void InitCache(const ScriptableObjects::IConfigRecord& config);
|
||||
|
|
|
@ -24,7 +24,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
// Note: This will go away with the redesigned reader interface.
|
||||
// TODO: callers of this often do ComputationNetwork::BumpEvalTimeStamp(featureNodes) and also for labels; we should eliminate the need for this.
|
||||
template <class ElemType>
|
||||
static bool GetMinibatchIntoNetwork(IDataReader<ElemType>& trainSetDataReader,
|
||||
static bool GetMinibatchIntoNetwork(IDataReader& trainSetDataReader,
|
||||
ComputationNetworkPtr net,
|
||||
ComputationNodeBasePtr criterionNode,
|
||||
bool useDistributedMBReading,
|
||||
|
@ -345,7 +345,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
}
|
||||
}
|
||||
|
||||
size_t GetMinibatchIntoCache(IDataReader<ElemType>& trainSetDataReader,
|
||||
size_t GetMinibatchIntoCache(IDataReader& trainSetDataReader,
|
||||
ComputationNetwork& net,
|
||||
StreamMinibatchInputs& inputMatrices,
|
||||
size_t requestedSubminibatches)
|
||||
|
|
|
@ -36,8 +36,8 @@ template SGD<double>::SGD(const ScriptableObjects::IConfigRecord&);
|
|||
|
||||
template <class ElemType>
|
||||
void SGD<ElemType>::Train(function<ComputationNetworkPtr(DEVICEID_TYPE)> createNetworkFn, DEVICEID_TYPE deviceId,
|
||||
IDataReader<ElemType>* trainSetDataReader,
|
||||
IDataReader<ElemType>* validationSetDataReader,
|
||||
IDataReader* trainSetDataReader,
|
||||
IDataReader* validationSetDataReader,
|
||||
const bool makeMode)
|
||||
{
|
||||
// determine which epoch to start with, including recovering a checkpoint if any and 'makeMode' enabled
|
||||
|
@ -80,8 +80,8 @@ void SGD<ElemType>::Train(function<ComputationNetworkPtr(DEVICEID_TYPE)> createN
|
|||
|
||||
template <class ElemType>
|
||||
void SGD<ElemType>::Adapt(wstring origModelFileName, wstring refNodeName,
|
||||
IDataReader<ElemType>* trainSetDataReader,
|
||||
IDataReader<ElemType>* validationSetDataReader,
|
||||
IDataReader* trainSetDataReader,
|
||||
IDataReader* validationSetDataReader,
|
||||
const DEVICEID_TYPE deviceId, const bool makeMode)
|
||||
{
|
||||
int startEpoch = DetermineStartEpoch(makeMode);
|
||||
|
@ -139,8 +139,8 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
|
|||
bool networkLoadedFromCheckpoint,
|
||||
ComputationNetworkPtr refNet,
|
||||
ComputationNodeBasePtr refNode,
|
||||
IDataReader<ElemType>* trainSetDataReader,
|
||||
IDataReader<ElemType>* validationSetDataReader)
|
||||
IDataReader* trainSetDataReader,
|
||||
IDataReader* validationSetDataReader)
|
||||
{
|
||||
auto& featureNodes = net->FeatureNodes();
|
||||
auto& labelNodes = net->LabelNodes();
|
||||
|
@ -188,7 +188,7 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
|
|||
{
|
||||
auto& nodes = (pass == 0) ? featureNodes : labelNodes;
|
||||
for (const auto & node : nodes)
|
||||
(*inputMatrices).AddInputMatrix(node->NodeName(), dynamic_pointer_cast<ComputationNode<ElemType>>(node)->ValuePtr());
|
||||
(*inputMatrices).AddInputMatrix(node->NodeName(), node->ValuePtr());
|
||||
}
|
||||
|
||||
// get hmm file for sequence training
|
||||
|
@ -699,7 +699,7 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
|
|||
const ComputationNodeBasePtr& refNode,
|
||||
const int epochNumber,
|
||||
const size_t epochSize,
|
||||
IDataReader<ElemType>* trainSetDataReader,
|
||||
IDataReader* trainSetDataReader,
|
||||
const double learnRatePerSample,
|
||||
size_t tunedMBSize,
|
||||
const std::vector<ComputationNodeBasePtr>& featureNodes,
|
||||
|
@ -843,7 +843,7 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
|
|||
// get minibatch
|
||||
// TODO: is it guaranteed that the GPU is already completed at this point, is it safe to overwrite the buffers?
|
||||
size_t actualMBSize = 0;
|
||||
bool wasDataRead = DataReaderHelpers::GetMinibatchIntoNetwork(*trainSetDataReader, net, criterionNodes[0],
|
||||
bool wasDataRead = DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*trainSetDataReader, net, criterionNodes[0],
|
||||
useDistributedMBReading, useParallelTrain, *inputMatrices, actualMBSize);
|
||||
if (!wasDataRead && (!useDistributedMBReading || noMoreSamplesToProcess)) // in case of distributed reading, we do a few more loops until all ranks have completed
|
||||
break; // end of epoch
|
||||
|
@ -1277,7 +1277,7 @@ std::vector<ComputationNodeBasePtr>& SGD<ElemType>::GetEvalCriterionNodes(Comput
|
|||
// Returns true if precomputation was executed.
|
||||
template <class ElemType>
|
||||
bool SGD<ElemType>::PreCompute(ComputationNetworkPtr net,
|
||||
IDataReader<ElemType>* trainSetDataReader,
|
||||
IDataReader* trainSetDataReader,
|
||||
std::vector<ComputationNodeBasePtr>& featureNodes,
|
||||
std::vector<ComputationNodeBasePtr>& labelNodes,
|
||||
StreamMinibatchInputs* inputMatrices)
|
||||
|
@ -1312,7 +1312,7 @@ bool SGD<ElemType>::PreCompute(ComputationNetworkPtr net,
|
|||
const size_t numIterationsBeforePrintingProgress = 100;
|
||||
size_t numItersSinceLastPrintOfProgress = 0;
|
||||
size_t actualMBSizeDummy;
|
||||
while (DataReaderHelpers::GetMinibatchIntoNetwork(*trainSetDataReader, net, nullptr, false, false, *inputMatrices, actualMBSizeDummy))
|
||||
while (DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*trainSetDataReader, net, nullptr, false, false, *inputMatrices, actualMBSizeDummy))
|
||||
{
|
||||
// TODO: move these into GetMinibatchIntoNetwork() --but those are passed around; necessary? Can't we get them from 'net'?
|
||||
ComputationNetwork::BumpEvalTimeStamp(featureNodes);
|
||||
|
@ -1347,7 +1347,7 @@ double SGD<ElemType>::SearchForBestLearnRate(ComputationNetworkPtr net,
|
|||
ComputationNetworkPtr refNet,
|
||||
const ComputationNodeBasePtr& refNode, const int epochNumber,
|
||||
const double curLearnRate,
|
||||
IDataReader<ElemType>* trainSetDataReader,
|
||||
IDataReader* trainSetDataReader,
|
||||
const std::vector<ComputationNodeBasePtr>& featureNodes,
|
||||
const std::vector<ComputationNodeBasePtr>& labelNodes,
|
||||
const std::vector<ComputationNodeBasePtr>& criterionNodes,
|
||||
|
@ -1514,7 +1514,7 @@ size_t SGD<ElemType>::AdaptiveMinibatchSizing(ComputationNetworkPtr net,
|
|||
const ComputationNodeBasePtr& refNode,
|
||||
const int epochNumber,
|
||||
const size_t numFramesToUseInSearch,
|
||||
IDataReader<ElemType>* trainSetDataReader,
|
||||
IDataReader* trainSetDataReader,
|
||||
const double learnRatePerSample,
|
||||
const size_t initialMinibatchSize,
|
||||
const std::vector<ComputationNodeBasePtr>& featureNodes,
|
||||
|
@ -1620,7 +1620,7 @@ size_t SGD<ElemType>::SearchForBestMinibatchSize(ComputationNetworkPtr net,
|
|||
const ComputationNodeBasePtr& refNode,
|
||||
const int epochNumber,
|
||||
const size_t numFramesToUseInSearch,
|
||||
IDataReader<ElemType>* trainSetDataReader,
|
||||
IDataReader* trainSetDataReader,
|
||||
const double learnRatePerSample,
|
||||
const std::vector<ComputationNodeBasePtr>& featureNodes,
|
||||
const std::vector<ComputationNodeBasePtr>& labelNodes,
|
||||
|
@ -1716,7 +1716,7 @@ template <class ElemType>
|
|||
void SGD<ElemType>::TrainOneMiniEpochAndReloadModel(ComputationNetworkPtr net,
|
||||
ComputationNetworkPtr refNet,
|
||||
const ComputationNodeBasePtr& refNode, const int epochNumber,
|
||||
const size_t epochSize, IDataReader<ElemType>* trainSetDataReader,
|
||||
const size_t epochSize, IDataReader* trainSetDataReader,
|
||||
const double learnRatePerSample,
|
||||
const size_t minibatchSize,
|
||||
const std::vector<ComputationNodeBasePtr>& featureNodes,
|
||||
|
@ -1773,7 +1773,7 @@ void SGD<ElemType>::TrainOneMiniEpochAndReloadModel(ComputationNetworkPtr net,
|
|||
// TODO: move the two-forward-pass support out of the reader.
|
||||
template <class ElemType>
|
||||
void SGD<ElemType>::AttemptUtteranceDerivativeFeatures(ComputationNetworkPtr net,
|
||||
IDataReader<ElemType>* trainSetDataReader,
|
||||
IDataReader* trainSetDataReader,
|
||||
const std::vector<ComputationNodeBasePtr>& featureNodes,
|
||||
StreamMinibatchInputs* inputMatrices)
|
||||
{
|
||||
|
|
|
@ -296,12 +296,12 @@ public:
|
|||
}
|
||||
|
||||
void Train(function<ComputationNetworkPtr(DEVICEID_TYPE)> createNetworkFn, DEVICEID_TYPE deviceId,
|
||||
IDataReader<ElemType>* trainSetDataReader,
|
||||
IDataReader<ElemType>* validationSetDataReader,
|
||||
IDataReader* trainSetDataReader,
|
||||
IDataReader* validationSetDataReader,
|
||||
const bool makeMode = true);
|
||||
void Adapt(wstring origModelFileName, wstring refNodeName,
|
||||
IDataReader<ElemType>* trainSetDataReader,
|
||||
IDataReader<ElemType>* validationSetDataReader,
|
||||
IDataReader* trainSetDataReader,
|
||||
IDataReader* validationSetDataReader,
|
||||
const DEVICEID_TYPE deviceID, const bool makeMode = true);
|
||||
|
||||
protected:
|
||||
|
@ -313,14 +313,14 @@ protected:
|
|||
bool networkLoadedFromCheckpoint,
|
||||
ComputationNetworkPtr refNet,
|
||||
ComputationNodeBasePtr refNode,
|
||||
IDataReader<ElemType>* trainSetDataReader,
|
||||
IDataReader<ElemType>* validationSetDataReader);
|
||||
IDataReader* trainSetDataReader,
|
||||
IDataReader* validationSetDataReader);
|
||||
|
||||
protected:
|
||||
|
||||
// return true if precomputation is executed.
|
||||
bool PreCompute(ComputationNetworkPtr net,
|
||||
IDataReader<ElemType>* trainSetDataReader,
|
||||
IDataReader* trainSetDataReader,
|
||||
std::vector<ComputationNodeBasePtr>& featureNodes,
|
||||
std::vector<ComputationNodeBasePtr>& labelNodes,
|
||||
StreamMinibatchInputs* inputMatrices);
|
||||
|
@ -330,7 +330,7 @@ protected:
|
|||
ComputationNetworkPtr refNet,
|
||||
const ComputationNodeBasePtr& refNode, const int epochNumber,
|
||||
const double curLearnRate,
|
||||
IDataReader<ElemType>* trainSetDataReader,
|
||||
IDataReader* trainSetDataReader,
|
||||
const std::vector<ComputationNodeBasePtr>& featureNodes,
|
||||
const std::vector<ComputationNodeBasePtr>& labelNodes,
|
||||
const std::vector<ComputationNodeBasePtr>& criterionNodes,
|
||||
|
@ -344,7 +344,7 @@ protected:
|
|||
void TrainOneMiniEpochAndReloadModel(ComputationNetworkPtr net,
|
||||
ComputationNetworkPtr refNet,
|
||||
const ComputationNodeBasePtr& refNode, const int epochNumber,
|
||||
const size_t epochSize, IDataReader<ElemType>* trainSetDataReader,
|
||||
const size_t epochSize, IDataReader* trainSetDataReader,
|
||||
const double learnRatePerSample,
|
||||
const size_t minibatchSize,
|
||||
const std::vector<ComputationNodeBasePtr>& featureNodes,
|
||||
|
@ -364,7 +364,7 @@ protected:
|
|||
const ComputationNodeBasePtr& refNode,
|
||||
const int epochNumber,
|
||||
const size_t numFramesToUseInSearch,
|
||||
IDataReader<ElemType>* trainSetDataReader,
|
||||
IDataReader* trainSetDataReader,
|
||||
const double learnRatePerSample,
|
||||
const size_t initialMinibatchSize,
|
||||
const std::vector<ComputationNodeBasePtr>& featureNodes,
|
||||
|
@ -383,7 +383,7 @@ protected:
|
|||
const ComputationNodeBasePtr& refNode,
|
||||
const int epochNumber,
|
||||
const size_t numFramesToUseInSearch,
|
||||
IDataReader<ElemType>* trainSetDataReader,
|
||||
IDataReader* trainSetDataReader,
|
||||
const double learnRatePerSample,
|
||||
const std::vector<ComputationNodeBasePtr>& featureNodes,
|
||||
const std::vector<ComputationNodeBasePtr>& labelNodes,
|
||||
|
@ -400,7 +400,7 @@ protected:
|
|||
// processing more utterances at the same time. Only used in Kaldi2Reader.
|
||||
// TODO: move the two-forward-pass support out of the reader.
|
||||
void AttemptUtteranceDerivativeFeatures(ComputationNetworkPtr net,
|
||||
IDataReader<ElemType>* trainSetDataReader,
|
||||
IDataReader* trainSetDataReader,
|
||||
const std::vector<ComputationNodeBasePtr>& featureNodes,
|
||||
StreamMinibatchInputs* inputMatrices);
|
||||
|
||||
|
@ -409,7 +409,7 @@ protected:
|
|||
const ComputationNodeBasePtr& refNode,
|
||||
const int epochNumber,
|
||||
const size_t epochSize,
|
||||
IDataReader<ElemType>* trainSetDataReader,
|
||||
IDataReader* trainSetDataReader,
|
||||
const double learnRatePerSample,
|
||||
size_t tunedMBSize,
|
||||
const std::vector<ComputationNodeBasePtr>& featureNodes,
|
||||
|
|
|
@ -30,7 +30,7 @@ public:
|
|||
}
|
||||
|
||||
// returns evaluation node values per sample determined by evalNodeNames (which can include both training and eval criterion nodes)
|
||||
vector<double> Evaluate(IDataReader<ElemType>* dataReader, const vector<wstring>& evalNodeNames, const size_t mbSize, const size_t testSize = requestDataSize)
|
||||
vector<double> Evaluate(IDataReader* dataReader, const vector<wstring>& evalNodeNames, const size_t mbSize, const size_t testSize = requestDataSize)
|
||||
{
|
||||
// determine nodes to evaluate
|
||||
std::vector<ComputationNodeBasePtr> evalNodes;
|
||||
|
@ -77,9 +77,9 @@ public:
|
|||
|
||||
StreamMinibatchInputs inputMatrices;
|
||||
for (auto& node : featureNodes)
|
||||
inputMatrices.AddInputMatrix(node->NodeName(), node->As<ComputationNode<ElemType>>()->ValuePtr());
|
||||
inputMatrices.AddInputMatrix(node->NodeName(), node->ValuePtr());
|
||||
for (auto& node : labelNodes)
|
||||
inputMatrices.AddInputMatrix(node->NodeName(), node->As<ComputationNode<ElemType>>()->ValuePtr());
|
||||
inputMatrices.AddInputMatrix(node->NodeName(), node->ValuePtr());
|
||||
|
||||
// evaluate through minibatches
|
||||
size_t totalEpochSamples = 0;
|
||||
|
@ -95,7 +95,7 @@ public:
|
|||
dataReader->StartMinibatchLoop(mbSize, 0, testSize);
|
||||
m_net->StartEvaluateMinibatchLoop(evalNodes);
|
||||
|
||||
while (DataReaderHelpers::GetMinibatchIntoNetwork(*dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize))
|
||||
while (DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize))
|
||||
{
|
||||
ComputationNetwork::BumpEvalTimeStamp(featureNodes);
|
||||
ComputationNetwork::BumpEvalTimeStamp(labelNodes);
|
||||
|
|
|
@ -74,7 +74,7 @@ private:
|
|||
{
|
||||
StreamMinibatchInputs inputMatrices;
|
||||
for (auto& node : inputNodes)
|
||||
inputMatrices.AddInputMatrix(node->NodeName(), node->As<ComputationNode<ElemType>>()->ValuePtr());
|
||||
inputMatrices.AddInputMatrix(node->NodeName(), node->ValuePtr());
|
||||
return inputMatrices;
|
||||
}
|
||||
|
||||
|
@ -84,7 +84,7 @@ public:
|
|||
{
|
||||
}
|
||||
|
||||
void WriteOutput(IDataReader<ElemType>& dataReader, size_t mbSize, IDataWriter<ElemType>& dataWriter, const std::vector<std::wstring>& outputNodeNames, size_t numOutputSamples = requestDataSize, bool doUnitTest = false)
|
||||
void WriteOutput(IDataReader& dataReader, size_t mbSize, IDataWriter& dataWriter, const std::vector<std::wstring>& outputNodeNames, size_t numOutputSamples = requestDataSize, bool doUnitTest = false)
|
||||
{
|
||||
std::vector<ComputationNodeBasePtr> outputNodes = DetermineOutputNodes(outputNodeNames);
|
||||
std::vector<ComputationNodeBasePtr> inputNodes = DetermineInputNodes(outputNodes);
|
||||
|
@ -104,7 +104,7 @@ public:
|
|||
std::map<std::wstring, void*, nocase_compare> outputMatrices;
|
||||
|
||||
size_t actualMBSize;
|
||||
while (DataReaderHelpers::GetMinibatchIntoNetwork(dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize))
|
||||
while (DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize))
|
||||
{
|
||||
ComputationNetwork::BumpEvalTimeStamp(inputNodes);
|
||||
|
||||
|
@ -174,7 +174,7 @@ public:
|
|||
};
|
||||
|
||||
// TODO: Remove code dup with above function by creating a fake Writer object and then calling the other function.
|
||||
void WriteOutput(IDataReader<ElemType>& dataReader, size_t mbSize, std::wstring outputPath, const std::vector<std::wstring>& outputNodeNames, const WriteFormattingOptions & formattingOptions, size_t numOutputSamples = requestDataSize)
|
||||
void WriteOutput(IDataReader& dataReader, size_t mbSize, std::wstring outputPath, const std::vector<std::wstring>& outputNodeNames, const WriteFormattingOptions & formattingOptions, size_t numOutputSamples = requestDataSize)
|
||||
{
|
||||
std::vector<ComputationNodeBasePtr> outputNodes = DetermineOutputNodes(outputNodeNames);
|
||||
std::vector<ComputationNodeBasePtr> inputNodes = DetermineInputNodes(outputNodes);
|
||||
|
@ -221,7 +221,7 @@ public:
|
|||
std::string valueFormatString = "%" + formattingOptions.precisionFormat + formatChar; // format string used in fprintf() for formatting the values
|
||||
|
||||
size_t actualMBSize;
|
||||
while (DataReaderHelpers::GetMinibatchIntoNetwork(dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize))
|
||||
while (DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize))
|
||||
{
|
||||
ComputationNetwork::BumpEvalTimeStamp(inputNodes);
|
||||
|
||||
|
|
|
@ -48,7 +48,7 @@ void DoCommand(const ConfigParameters& configRoot)
|
|||
|
||||
Eval<ElemType> eval(config);
|
||||
|
||||
DataReader<ElemType>* dataReader = new DataReader<ElemType>(readerConfig);
|
||||
auto dataReader = make_shared<DataReader>(readerConfig);
|
||||
eval.LoadModel(modelPath);
|
||||
dataReader->StartMinibatchLoop(mbSize, 0, epochSize);
|
||||
eval.StartEvaluateMinibatchLoop(outputName);
|
||||
|
|
|
@ -146,7 +146,7 @@ struct ReaderFixture
|
|||
template <class ElemType>
|
||||
void HelperWriteReaderContentToFile(
|
||||
ofstream& outputFile,
|
||||
DataReader<ElemType>& dataReader,
|
||||
DataReader& dataReader,
|
||||
StreamMinibatchInputs& map,
|
||||
size_t epochs,
|
||||
size_t mbSize,
|
||||
|
@ -225,7 +225,7 @@ struct ReaderFixture
|
|||
const ConfigParameters simpleDemoConfig = config(testSectionName);
|
||||
const ConfigParameters readerConfig = simpleDemoConfig(readerSectionName);
|
||||
|
||||
DataReader<ElemType> dataReader(readerConfig);
|
||||
DataReader dataReader(readerConfig);
|
||||
|
||||
StreamMinibatchInputs map;
|
||||
std::vector<shared_ptr<Matrix<ElemType>>> features;
|
||||
|
@ -250,7 +250,7 @@ struct ReaderFixture
|
|||
ofstream outputFile(testDataFilePath, ios::out);
|
||||
|
||||
// Perform the data reading
|
||||
HelperWriteReaderContentToFile(outputFile, dataReader, map, epochs, mbSize, epochSize, numFeatureFiles, numLabelFiles, subsetNum, numSubsets);
|
||||
HelperWriteReaderContentToFile<ElemType>(outputFile, dataReader, map, epochs, mbSize, epochSize, numFeatureFiles, numLabelFiles, subsetNum, numSubsets);
|
||||
|
||||
outputFile.close();
|
||||
|
||||
|
@ -286,8 +286,9 @@ struct ReaderFixture
|
|||
const ConfigParameters simpleDemoConfig = config(testSectionName);
|
||||
const ConfigParameters readerConfig = simpleDemoConfig(readerSectionName);
|
||||
|
||||
BOOST_CHECK_THROW(DataReader<ElemType> dataReader(readerConfig), ExceptionType);
|
||||
BOOST_CHECK_THROW(DataReader dataReader(readerConfig), ExceptionType);
|
||||
}
|
||||
};
|
||||
}
|
||||
} } }
|
||||
|
||||
}}}
|
||||
|
|
Загрузка…
Ссылка в новой задаче