CNTK/Source/EvalDll/EvalReader.h

263 строки
10 KiB
C++

//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#pragma once
#define DATAREADER_LOCAL
#include "DataReader.h"
namespace Microsoft { namespace MSR { namespace CNTK {
// Evaluation Reader class
// interface to pass to evaluation DLL
template <class ElemType>
class EvalReader : public DataReaderBase
{
std::map<std::wstring, std::vector<ElemType>*>* m_inputs; // our input data
std::map<std::wstring, size_t>* m_dimensions; // the number of rows for the input data
size_t m_recordCount; // count of records in this data
size_t m_currentRecord; // next record number to read
size_t m_mbSize;
vector<size_t> m_switchFrame;
size_t m_oldSig;
public:
// Method to setup the data for the reader
void SetData(std::map<std::wstring, std::vector<ElemType>*>* inputs, std::map<std::wstring, size_t>* dimensions)
{
m_inputs = inputs;
m_dimensions = dimensions;
m_currentRecord = 0;
m_recordCount = 0;
for (auto iter = inputs->begin(); iter != inputs->end(); ++iter)
{
// figure out the dimension of the data
const std::wstring& val = iter->first;
size_t count = (*inputs)[val]->size();
size_t rows = (*dimensions)[val];
size_t recordCount = count / rows;
if (m_recordCount != 0)
{
// record count must be the same for all the data
if (recordCount != m_recordCount)
RuntimeError("Record Count of %ls (%lux%lu) does not match the record count of previous entries (%lu).", val.c_str(), rows, recordCount, m_recordCount);
}
else
{
m_recordCount = recordCount;
}
}
}
void SetBoundary(size_t newSig)
{
if (m_switchFrame.size() == 0)
{
m_oldSig = newSig;
m_switchFrame.assign(1, 0);
}
else
{
if (m_oldSig == newSig)
{
m_switchFrame[0] = m_mbSize + 8888; // TODO: WTF??
}
else
{
m_switchFrame[0] = 0;
m_oldSig = newSig;
}
}
}
virtual void Init(const ConfigParameters& /*config*/) override
{
}
virtual void Init(const ScriptableObjects::IConfigRecord& /*config*/) override
{
}
// Destroy - cleanup and remove this class
// NOTE: this destroys the object, and it can't be used past this point
virtual void Destroy()
{
delete this;
}
// EvalReader Constructor
// config - [in] configuration parameters for the datareader
template <class ConfigRecordType>
EvalReader(const ConfigRecordType& config)
{
m_recordCount = m_currentRecord = 0;
Init(config);
}
// Destructor - free up the matrix values we allocated
virtual ~EvalReader()
{
}
// StartMinibatchLoop - Startup a minibatch loop
// mbSize - [in] size of the minibatch (number of frames, etc.)
// epoch - [in] epoch number for this loop
// requestedEpochSamples - [in] number of samples to randomize, defaults to requestDataSize which uses the number of samples there are in the dataset
virtual void StartMinibatchLoop(size_t mbSize, size_t /*epoch*/, size_t /*requestedEpochSamples=requestDataSize*/)
{
m_mbSize = min(mbSize, m_recordCount);
}
// TryGetMinibatch - Get the next minibatch (features and labels)
// matrices - [in] a map with named matrix types (i.e. 'features', 'labels') mapped to the corresponding matrix,
// [out] each matrix resized if necessary containing data.
// returns - true if there are more minibatches, false if no more minibatchs remain
virtual bool TryGetMinibatch(StreamMinibatchInputs& matrices)
{
// how many records are we reading this time
size_t recordCount = min(m_mbSize, m_recordCount - m_currentRecord);
// check to see if we are out of records in this current dataset
if (m_currentRecord >= m_recordCount)
return false;
// loop through all the input vectors to copy the data over
for (auto iter = m_inputs->begin(); iter != m_inputs->end(); ++iter)
{
// figure out the dimension of the data
const auto& name = iter->first;
size_t rows = (*m_dimensions)[name];
// size_t count = rows*recordCount;
// find the output matrix we want to fill
if (!matrices.HasInput(name))
RuntimeError("No matrix data found for key '%ls'.", name.c_str());
// allocate the matrix if we don't have one yet
auto& matrix = matrices.GetInputMatrix<ElemType>(name);
// copy over the data
std::vector<ElemType>* data = iter->second;
ElemType* dataPtr = data->data() + (m_currentRecord * rows);
matrix.SetValue(rows, recordCount, matrix.GetDeviceId(), dataPtr, matrixFlagNormal);
}
// increment our record pointer
m_currentRecord += recordCount;
// return true if we returned any data whatsoever
return true;
}
size_t GetNumParallelSequencesForFixingBPTTMode()
{
return 1;
}
void SetNumParallelSequences(const size_t)
{
}
void SetSentenceSegBatch(std::vector<size_t>& sentenceEnd)
{
sentenceEnd.resize(m_switchFrame.size());
for (size_t i = 0; i < m_switchFrame.size(); i++)
{
sentenceEnd[i] = m_switchFrame[i];
}
}
void CopyMBLayoutTo(MBLayoutPtr pMBLayout)
{
assert(m_switchFrame.size() == 1);
pMBLayout->Init(1, m_mbSize);
// BUGBUG: The following code is somewhat broken in that the structure of this module only keeps track of new sentence starts,
// but not of ends. But end markers are now required by the MBLayout. So we must fake the end markers.
// That will fail if the previous sentence end fell on the boundary; then we will miss the end flag.
// This still works for a left-to-right model since for eval we only really look at the start flag.
// So we get lucky, sort of. Not nice.
// The correct solution is to rewrite this entire module to be more direct; no Reader needed, we can call ForwardProp() directly.
// BUGBUG: The module also does not keep track of the actual start in the past. So we fake the start, too.
// There are boundary cases where this will be incorrect for models with a delay of >1 step.
if (m_switchFrame[0] < m_mbSize) /* there is a switch frame within the minibatch */
{
// finish the current sequence
if (m_switchFrame[0] > 0) // BUGBUG: gonna miss the previous end flag if starting on frame [0], see above.
pMBLayout->AddSequence(0, 0, -1, m_switchFrame[0] - 1);
// start the new sequence
// We use a fake end of 1 frame beyond the actual end of the minibatch.
pMBLayout->AddSequence(0, 0, m_switchFrame[0], m_mbSize + 1);
// pMBLayout->Set(0, m_switchFrame[0], MinibatchPackingFlags::SequenceStart);
// if (m_switchFrame[0] > 0)
// pMBLayout->Set(0, m_switchFrame[0] - 1, MinibatchPackingFlags::SequenceEnd); // TODO: can't we use Set()?
}
else // all frames in this MB belong to the same utterance
{
// no boundary inide the MB: fake a sequence that spans 1 frame on each side. BUGBUG: That's wrong for delays of > 1 step, see above.
pMBLayout->AddSequence(0, 0, -1, m_mbSize + 1); // BUGBUG: gonna miss the end flag if it ends at end of this MB, see above
}
}
void GetSentenceBoundary(std::vector<size_t> boundaryInfo)
{
m_switchFrame.resize(boundaryInfo.size());
for (size_t i = 0; i < m_switchFrame.size(); i++)
m_switchFrame[i] = boundaryInfo[i];
}
void SetRandomSeed(int)
{
NOT_IMPLEMENTED;
}
// GetLabelMapping - Gets the label mapping from integer index to label type
// returns - a map from numeric datatype to native label type
virtual const std::map<typename EvalReader<ElemType>::LabelIdType, typename EvalReader<ElemType>::LabelType>& GetLabelMapping(const std::wstring& /*sectionName*/)
{
static std::map<typename EvalReader<ElemType>::LabelIdType, typename EvalReader<ElemType>::LabelType> labelMap;
return labelMap;
}
// SetLabelMapping - Sets the label mapping from integer index to label
// labelMapping - mapping table from label values to IDs (must be 0-n)
// note: for tasks with labels, the mapping table must be the same between a training run and a testing run
virtual void SetLabelMapping(const std::wstring& /*sectionName*/, const std::map<typename EvalReader<ElemType>::LabelIdType, typename EvalReader<ElemType>::LabelType>& /*labelMapping*/)
{
}
// GetData - Gets metadata from the specified section (into CPU memory)
// sectionName - section name to retrieve data from
// numRecords - number of records to read
// data - pointer to data buffer, if NULL, dataBufferSize will be set to size of required buffer to accomidate request
// dataBufferSize - [in] size of the databuffer in bytes
// [out] size of buffer filled with data
// recordStart - record to start reading from, defaults to zero (start of data)
// returns: true if data remains to be read, false if the end of data was reached
virtual bool GetData(const std::wstring& /*sectionName*/, size_t /*numRecords*/, void* /*data*/, size_t& /*dataBufferSize*/, size_t /*recordStart=0*/)
{
return false;
}
virtual bool DataEnd()
{
return m_currentRecord < m_recordCount;
}
virtual bool GetMinibatch4SE(std::vector<shared_ptr<const msra::dbn::latticepair>>& /*latticeinput*/, vector<size_t>& /*uids*/,
vector<size_t>& /*boundaries*/, vector<size_t>& /*extrauttmap*/)
{
return true;
}
virtual bool GetHmmData(msra::asr::simplesenonehmm* /*hmm*/)
{
return true;
}
virtual void SetValidFrameInBatch(vector<size_t>& /*validFrame*/)
{
return;
}
};
} } }