//
//
// Copyright (c) Microsoft Corporation. All rights reserved.
//
//
#pragma once
#define DATAWRITER_LOCAL
#include "DataWriter.h"
namespace Microsoft { namespace MSR { namespace CNTK {
// Evaluation Writer class
// interface to pass to evaluation DLL
template
class EvalWriter : public IDataWriter
{
typedef typename IDataWriter::LabelType LabelType;
typedef typename IDataWriter::LabelIdType LabelIdType;
private:
std::map*>* m_outputs; // our output data
std::map* m_dimensions; // the number of rows for the output data
size_t m_recordCount; // count of records in this data
size_t m_currentRecord; // next record number to read
public:
// Method to setup the data for the reader
void SetData(std::map*>* outputs, std::map* dimensions)
{
m_outputs = outputs;
m_dimensions = dimensions;
m_currentRecord = 0;
m_recordCount = 0;
for (auto iter = outputs->begin(); iter != outputs->end(); ++iter)
{
// figure out the dimension of the data
const std::wstring& val = iter->first;
size_t count = (*outputs)[val]->size();
if (dimensions->find(val) == dimensions->end())
{
RuntimeError("Output %ls not found in CNTK model.", val.c_str());
}
size_t rows = (*dimensions)[val];
size_t recordCount = count/rows;
if (m_recordCount != 0)
{
// record count must be the same for all the data
if (recordCount != m_recordCount)
RuntimeError("Record Count of %ls (%lux%lu) does not match the record count of previous entries (%lu).", val.c_str(), rows, recordCount, m_recordCount);
}
else
{
m_recordCount = recordCount;
}
}
}
virtual void Init(const ConfigParameters& /*config*/)
{
}
// Destroy - cleanup and remove this class
// NOTE: this destroys the object, and it can't be used past this point
virtual void Destroy()
{
delete this;
}
// EvalWriter Constructor
// config - [in] configuration parameters for the datareader
EvalWriter(const ConfigParameters& config)
{
m_recordCount = m_currentRecord = 0;
Init(config);
}
// Destructor - free up the matrix values we allocated
virtual ~EvalWriter()
{
}
virtual void GetSections(std::map& /*sections*/)
{
assert(false);
NOT_IMPLEMENTED;
}
virtual bool SaveData(size_t /*recordStart*/, const std::map& matrices, size_t numRecords, size_t /*datasetSize*/, size_t /*byteVariableSized*/)
{
// loop through all the output vectors to copy the data over
for (auto iter = m_outputs->begin(); iter != m_outputs->end(); ++iter)
{
// figure out the dimension of the data
std::wstring val = iter->first;
size_t rows = (*m_dimensions)[val];
//size_t count = rows*numRecords;
// find the output matrix we want to fill
const std::map::const_iterator iterIn = matrices.find(val);
// allocate the matrix if we don't have one yet
if (iterIn == matrices.end())
{
RuntimeError("No matrix data found for key '%ls', cannot continue", val.c_str());
}
Matrix* matrix = (Matrix*)iterIn->second;
// copy over the data
std::vector* data = iter->second;
size_t index = m_currentRecord*rows;
size_t numberToCopy = rows*numRecords;
data->resize(index+numberToCopy);
void* dataPtr = (void*)((ElemType*)data->data() + index);
size_t dataSize = numberToCopy*sizeof(ElemType);
void* mat = &(*matrix)(0,0);
size_t matSize = matrix->GetNumElements()*sizeof(ElemType);
memcpy_s(dataPtr, dataSize, mat, matSize);
}
// increment our record pointer
m_currentRecord += numRecords;
// return the "done with all records" value
return (m_currentRecord >= m_recordCount);
}
virtual void SaveMapping(std::wstring saveId, const std::map::LabelIdType, typename EvalWriter::LabelType>& /*labelMapping*/) {};
};
}}}