2014-08-30 03:21:42 +04:00
|
|
|
//
|
|
|
|
// <copyright file="EvalReader.h" company="Microsoft">
|
|
|
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
|
|
// </copyright>
|
|
|
|
//
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
#define DATAREADER_LOCAL
|
|
|
|
#include "DataReader.h"
|
|
|
|
|
|
|
|
namespace Microsoft { namespace MSR { namespace CNTK {
|
|
|
|
|
|
|
|
// Evaluation Reader class
|
|
|
|
// interface to pass to evaluation DLL
|
|
|
|
template<class ElemType>
|
|
|
|
class EvalReader : public IDataReader<ElemType>
|
|
|
|
{
|
2014-11-05 07:07:35 +03:00
|
|
|
typedef typename IDataReader<ElemType>::LabelType LabelType;
|
|
|
|
typedef typename IDataReader<ElemType>::LabelIdType LabelIdType;
|
2014-08-30 03:21:42 +04:00
|
|
|
private:
|
|
|
|
std::map<std::wstring, std::vector<ElemType>*>* m_inputs; // our input data
|
|
|
|
std::map<std::wstring, size_t>* m_dimensions; // the number of rows for the input data
|
|
|
|
size_t m_recordCount; // count of records in this data
|
|
|
|
size_t m_currentRecord; // next record number to read
|
|
|
|
size_t m_mbSize;
|
2014-11-07 06:24:05 +03:00
|
|
|
vector<size_t> m_switchFrame;
|
|
|
|
size_t m_oldSig;
|
2014-08-30 03:21:42 +04:00
|
|
|
public:
|
|
|
|
// Method to setup the data for the reader
|
|
|
|
void SetData(std::map<std::wstring, std::vector<ElemType>*>* inputs, std::map<std::wstring, size_t>* dimensions)
|
|
|
|
{
|
|
|
|
m_inputs = inputs;
|
|
|
|
m_dimensions = dimensions;
|
|
|
|
m_currentRecord = 0;
|
|
|
|
m_recordCount = 0;
|
2014-11-05 00:47:58 +03:00
|
|
|
for (auto iter = inputs->begin(); iter != inputs->end(); ++iter)
|
2014-08-30 03:21:42 +04:00
|
|
|
{
|
|
|
|
// figure out the dimension of the data
|
|
|
|
const std::wstring& val = iter->first;
|
|
|
|
size_t count = (*inputs)[val]->size();
|
|
|
|
size_t rows = (*dimensions)[val];
|
|
|
|
size_t recordCount = count/rows;
|
|
|
|
|
|
|
|
|
|
|
|
if (m_recordCount != 0)
|
|
|
|
{
|
|
|
|
// record count must be the same for all the data
|
|
|
|
if (recordCount != m_recordCount)
|
2014-10-30 20:33:51 +03:00
|
|
|
RuntimeError("Record Count of %ls (%lux%lu) does not match the record count of previous entries (%lu).", val.c_str(), rows, recordCount, m_recordCount);
|
2014-08-30 03:21:42 +04:00
|
|
|
}
|
|
|
|
else
|
|
|
|
{
|
|
|
|
m_recordCount = recordCount;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2014-11-07 06:24:05 +03:00
|
|
|
void SetBoundary (size_t newSig)
|
|
|
|
{
|
|
|
|
if (m_switchFrame.size()==0)
|
|
|
|
{
|
|
|
|
m_oldSig = newSig;
|
|
|
|
m_switchFrame.assign(1,0);
|
|
|
|
} else
|
|
|
|
{
|
|
|
|
if (m_oldSig==newSig)
|
|
|
|
{
|
2015-09-26 02:24:02 +03:00
|
|
|
m_switchFrame[0] = m_mbSize+8888; // TODO: WTF??
|
2014-11-07 06:24:05 +03:00
|
|
|
}
|
|
|
|
else
|
|
|
|
{
|
|
|
|
m_switchFrame[0] = 0;
|
|
|
|
m_oldSig = newSig;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
2014-08-30 03:21:42 +04:00
|
|
|
|
2014-10-15 09:39:19 +04:00
|
|
|
virtual void Init(const ConfigParameters& /*config*/)
|
2014-08-30 03:21:42 +04:00
|
|
|
{
|
|
|
|
}
|
|
|
|
|
|
|
|
// Destroy - cleanup and remove this class
|
|
|
|
// NOTE: this destroys the object, and it can't be used past this point
|
|
|
|
virtual void Destroy()
|
|
|
|
{
|
|
|
|
delete this;
|
|
|
|
}
|
|
|
|
|
|
|
|
// EvalReader Constructor
|
|
|
|
// config - [in] configuration parameters for the datareader
|
|
|
|
EvalReader(const ConfigParameters& config)
|
|
|
|
{
|
|
|
|
m_recordCount = m_currentRecord = 0;
|
|
|
|
Init(config);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Destructor - free up the matrix values we allocated
|
|
|
|
virtual ~EvalReader()
|
|
|
|
{
|
|
|
|
}
|
|
|
|
|
|
|
|
//StartMinibatchLoop - Startup a minibatch loop
|
|
|
|
// mbSize - [in] size of the minibatch (number of frames, etc.)
|
|
|
|
// epoch - [in] epoch number for this loop
|
|
|
|
// requestedEpochSamples - [in] number of samples to randomize, defaults to requestDataSize which uses the number of samples there are in the dataset
|
2014-10-15 09:39:19 +04:00
|
|
|
virtual void StartMinibatchLoop(size_t mbSize, size_t /*epoch*/, size_t /*requestedEpochSamples=requestDataSize*/)
|
2014-08-30 03:21:42 +04:00
|
|
|
{
|
|
|
|
m_mbSize = min(mbSize,m_recordCount);
|
|
|
|
}
|
|
|
|
|
|
|
|
// GetMinibatch - Get the next minibatch (features and labels)
|
|
|
|
// matrices - [in] a map with named matrix types (i.e. 'features', 'labels') mapped to the corresponing matrix,
|
|
|
|
// [out] each matrix resized if necessary containing data.
|
|
|
|
// returns - true if there are more minibatches, false if no more minibatchs remain
|
|
|
|
virtual bool GetMinibatch(std::map<std::wstring, Matrix<ElemType>*>& matrices)
|
|
|
|
{
|
|
|
|
// how many records are we reading this time
|
|
|
|
size_t recordCount = min(m_mbSize, m_recordCount-m_currentRecord);
|
|
|
|
|
|
|
|
// check to see if we are out of records in this current dataset
|
|
|
|
if (m_currentRecord >= m_recordCount)
|
|
|
|
return false;
|
|
|
|
|
|
|
|
// loop through all the input vectors to copy the data over
|
2014-11-05 00:47:58 +03:00
|
|
|
for (auto iter = m_inputs->begin(); iter != m_inputs->end(); ++iter)
|
2014-08-30 03:21:42 +04:00
|
|
|
{
|
|
|
|
// figure out the dimension of the data
|
|
|
|
std::wstring val = iter->first;
|
|
|
|
size_t rows = (*m_dimensions)[val];
|
2014-10-15 09:39:19 +04:00
|
|
|
//size_t count = rows*recordCount;
|
2014-08-30 03:21:42 +04:00
|
|
|
|
|
|
|
// find the output matrix we want to fill
|
2014-11-05 00:47:58 +03:00
|
|
|
auto iterIn = matrices.find(val);
|
2014-08-30 03:21:42 +04:00
|
|
|
|
|
|
|
// allocate the matrix if we don't have one yet
|
|
|
|
if (iterIn == matrices.end())
|
|
|
|
{
|
2014-10-30 20:33:51 +03:00
|
|
|
RuntimeError("No matrix data found for key '%ls', cannot continue", val.c_str());
|
2014-08-30 03:21:42 +04:00
|
|
|
}
|
|
|
|
|
|
|
|
Matrix<ElemType>* matrix = iterIn->second;
|
|
|
|
|
|
|
|
// resize to the proper size to hold the data
|
|
|
|
matrix->Resize(rows, recordCount);
|
|
|
|
|
|
|
|
// copy over the data
|
|
|
|
std::vector<ElemType>* data = iter->second;
|
2014-10-15 09:39:19 +04:00
|
|
|
//size_t = m_currentRecord*rows;
|
2014-08-30 03:21:42 +04:00
|
|
|
void* mat = &(*matrix)(0,0);
|
|
|
|
size_t matSize = matrix->GetNumElements()*sizeof(ElemType);
|
2015-06-28 09:33:06 +03:00
|
|
|
void* dataPtr = (void*)((ElemType*)data->data() + m_currentRecord*rows);
|
2014-08-30 03:21:42 +04:00
|
|
|
size_t dataSize = rows*recordCount*sizeof(ElemType);
|
|
|
|
memcpy_s(mat, matSize, dataPtr, dataSize);
|
|
|
|
}
|
|
|
|
|
|
|
|
// increment our record pointer
|
|
|
|
m_currentRecord += recordCount;
|
|
|
|
|
|
|
|
// return true if we returned any data whatsoever
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
2015-09-19 03:35:07 +03:00
|
|
|
size_t GetNumParallelSequences() { return 1; }
|
2014-08-30 03:21:42 +04:00
|
|
|
|
2015-09-19 03:35:07 +03:00
|
|
|
void SetNumParallelSequences(const size_t ) {}
|
2015-04-09 07:20:12 +03:00
|
|
|
void SetSentenceSegBatch(std::vector<size_t> &sentenceEnd)
|
2014-11-07 06:24:05 +03:00
|
|
|
{
|
|
|
|
sentenceEnd.resize(m_switchFrame.size());
|
2015-04-09 07:20:12 +03:00
|
|
|
for (size_t i = 0; i < m_switchFrame.size(); i++)
|
2014-11-07 06:24:05 +03:00
|
|
|
{
|
|
|
|
sentenceEnd[i] = m_switchFrame[i];
|
|
|
|
}
|
|
|
|
}
|
2015-09-16 23:24:54 +03:00
|
|
|
void CopyMBLayoutTo(MBLayoutPtr pMBLayout)
|
2015-04-09 07:20:12 +03:00
|
|
|
{
|
2015-07-18 02:42:22 +03:00
|
|
|
assert(m_switchFrame.size() == 1);
|
2015-09-26 02:24:02 +03:00
|
|
|
pMBLayout->Init(1, m_mbSize, true/*sequential*/); // TODO: not sure if this is always sequential
|
2015-07-18 02:42:22 +03:00
|
|
|
|
|
|
|
if (m_switchFrame[0] < m_mbSize) /* there is a switch frame within the minibatch*/
|
|
|
|
{
|
2015-09-26 00:55:23 +03:00
|
|
|
pMBLayout->Set(0, m_switchFrame[0], MinibatchPackingFlags::SequenceStart);
|
2015-07-18 02:42:22 +03:00
|
|
|
if (m_switchFrame[0] > 0)
|
2015-09-26 00:55:23 +03:00
|
|
|
pMBLayout->SetWithoutOr(0, m_switchFrame[0] - 1, MinibatchPackingFlags::SequenceEnd); // TODO: can't we use Set()?
|
2015-07-18 02:42:22 +03:00
|
|
|
}
|
2015-04-09 07:20:12 +03:00
|
|
|
}
|
2015-07-08 23:52:44 +03:00
|
|
|
|
2014-11-07 06:24:05 +03:00
|
|
|
void GetSentenceBoundary(std::vector<size_t> boundaryInfo)
|
|
|
|
{
|
|
|
|
m_switchFrame.resize(boundaryInfo.size());
|
|
|
|
for (size_t i = 0; i < m_switchFrame.size(); i ++)
|
|
|
|
m_switchFrame[i] = boundaryInfo[i];
|
|
|
|
}
|
2015-04-09 07:20:12 +03:00
|
|
|
|
|
|
|
void SetRandomSeed(int) { NOT_IMPLEMENTED; }
|
|
|
|
|
2014-08-30 03:21:42 +04:00
|
|
|
// GetLabelMapping - Gets the label mapping from integer index to label type
|
|
|
|
// returns - a map from numeric datatype to native label type
|
2014-11-05 07:07:35 +03:00
|
|
|
virtual const std::map<typename EvalReader<ElemType>::LabelIdType, typename EvalReader<ElemType>::LabelType>& GetLabelMapping(const std::wstring& /*sectionName*/)
|
2014-08-30 03:21:42 +04:00
|
|
|
{
|
2014-11-05 07:07:35 +03:00
|
|
|
static std::map<typename EvalReader<ElemType>::LabelIdType, typename EvalReader<ElemType>::LabelType> labelMap;
|
2014-08-30 03:21:42 +04:00
|
|
|
return labelMap;
|
|
|
|
}
|
|
|
|
|
|
|
|
// SetLabelMapping - Sets the label mapping from integer index to label
|
|
|
|
// labelMapping - mapping table from label values to IDs (must be 0-n)
|
|
|
|
// note: for tasks with labels, the mapping table must be the same between a training run and a testing run
|
2014-11-05 07:07:35 +03:00
|
|
|
virtual void SetLabelMapping(const std::wstring& /*sectionName*/, const std::map<typename EvalReader<ElemType>::LabelIdType, typename EvalReader<ElemType>::LabelType>& /*labelMapping*/) {}
|
2014-08-30 03:21:42 +04:00
|
|
|
|
|
|
|
// GetData - Gets metadata from the specified section (into CPU memory)
|
|
|
|
// sectionName - section name to retrieve data from
|
|
|
|
// numRecords - number of records to read
|
|
|
|
// data - pointer to data buffer, if NULL, dataBufferSize will be set to size of required buffer to accomidate request
|
|
|
|
// dataBufferSize - [in] size of the databuffer in bytes
|
|
|
|
// [out] size of buffer filled with data
|
|
|
|
// recordStart - record to start reading from, defaults to zero (start of data)
|
|
|
|
// returns: true if data remains to be read, false if the end of data was reached
|
2014-10-15 09:39:19 +04:00
|
|
|
virtual bool GetData(const std::wstring& /*sectionName*/, size_t /*numRecords*/, void* /*data*/, size_t& /*dataBufferSize*/, size_t /*recordStart=0*/)
|
2014-08-30 03:21:42 +04:00
|
|
|
{
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
2014-10-15 09:39:19 +04:00
|
|
|
virtual bool DataEnd(EndDataType /*endDataType*/)
|
2014-08-30 03:21:42 +04:00
|
|
|
{
|
|
|
|
return m_currentRecord < m_recordCount;
|
|
|
|
}
|
2015-09-26 08:36:05 +03:00
|
|
|
|
|
|
|
virtual bool GetMinibatch4SE(std::vector<shared_ptr<const msra::dbn::latticesource::latticepair>> & /*latticeinput*/, vector<size_t> & /*uids*/,
|
|
|
|
vector<size_t> & /*boundaries*/, vector<size_t> &/*extrauttmap*/)
|
|
|
|
{
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
|
|
|
virtual bool GetHmmData(msra::asr::simplesenonehmm * /*hmm*/)
|
|
|
|
{
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
|
|
|
virtual void SetValidFrameInBatch(vector<size_t> &/*validFrame*/)
|
|
|
|
{
|
|
|
|
return;
|
|
|
|
}
|
2015-09-24 06:14:20 +03:00
|
|
|
|
2014-08-30 03:21:42 +04:00
|
|
|
};
|
|
|
|
|
2014-11-05 07:07:35 +03:00
|
|
|
}}}
|