// NOTE(review): this file looks like a C++ header (it starts with
// `#pragma once`) whose line breaks were lost. The line below starts with
// `//`, so the ENTIRE line is a single line comment. As it stands this file
// therefore declares nothing: the includes, the MAX_STRING macro, the
// GetWriter / GetWriterF / GetWriterD factories and the LMSequenceWriter
// class template are all inert text. Restore the original multi-line file
// from version control instead of re-splitting this line by hand.
//
// NOTE(review): the text also appears to have lost everything that was
// between angle brackets, so it cannot be reconstructed from here alone:
//   - the third `#include` (after "SequenceParser.h") has no header name;
//   - `template void ... GetWriter` and `template class LMSequenceWriter`
//     have no template parameter list (members use `ElemType`, so it was
//     presumably `<class ElemType>` -- confirm);
//   - members such as `std::vector outputDims`, `map outputFiles`,
//     `map> word4idx`, `noiseSampler m_noiseSampler`, and parameters such
//     as `const Matrix& outputData` and `const std::map& matrices`, are
//     missing their template arguments.
//
// What the text declares (once restored):
//   - MAX_STRING: macro defined as 2048; presumably a buffer length -- confirm
//     (a `constexpr` constant would be preferable to a macro).
//   - GetWriter: function template that stores a `new LMSequenceWriter()` in
//     *pwriter. Nothing visible here deletes it; who releases the object is
//     up to the caller -- confirm the expected cleanup path.
//   - GetWriterF / GetWriterD: `extern "C"` entry points that forward to
//     GetWriter. The F/D suffixes suggest float/double instantiations, but
//     the template arguments that would select them are missing -- confirm.
//   - LMSequenceWriter: derives from IDataWriter. Its public surface is
//     Init(const ConfigParameters&), Destroy(), SaveData(...), plus
//     GetSections and SaveMapping, which have empty bodies. The destructor
//     calls Destroy(). Private state includes the maps outputFiles,
//     outputFileIds, class_words, word4idx, idx4word, idx4class, idx4cnt,
//     mUnk (commented as the unk symbol) and nBests; the scalars class_size,
//     nwords and noise_sample_size; a noiseSampler; and the helpers
//     compare_val, SaveToFile (nbest defaults to 1) and ReadLabelInfo.
//
// NOTE(review) issues to address when restoring:
//   - GetWriter's body uses LMSequenceWriter before the class template is
//     declared; add a forward declaration or move the factories below the
//     class definition.
//   - GetWriterF / GetWriterD are non-inline, non-template function
//     definitions in a header. Including it from more than one translation
//     unit would produce duplicate-symbol link errors; move them into a
//     single .cpp file.
//   - GetWriter hands the object out as an IDataWriter*. For ~LMSequenceWriter
//     (and therefore Destroy()) to run when it is deleted through that
//     pointer, IDataWriter must have a virtual destructor -- verify in
//     DataWriter.h.
//   - Init / Destroy / SaveData are declared `virtual` without `override`.
//     If they are meant to override IDataWriter members, add `override` so
//     that signature drift is a compile error.
//   - ReadLabelInfo's parameters `word4idx` and `idx4word` have the same
//     names as data members, so they shadow those members inside its body.
//     Consider renaming them.
// // // Copyright (c) Microsoft Corporation. All rights reserved. // // #pragma once #include "DataWriter.h" #include "SequenceParser.h" #include #define MAX_STRING 2048 namespace Microsoft { namespace MSR { namespace CNTK { template void DATAWRITER_API GetWriter(IDataWriter** pwriter) { *pwriter = new LMSequenceWriter(); } extern "C" DATAWRITER_API void GetWriterF(IDataWriter** pwriter) { GetWriter(pwriter); } extern "C" DATAWRITER_API void GetWriterD(IDataWriter** pwriter) { GetWriter(pwriter); } template class LMSequenceWriter : public IDataWriter { private: std::vector outputDims; map outputFiles; map outputFileIds; std::vector udims; int class_size; map>> class_words; map> word4idx; map> idx4word; map> idx4class; map> idx4cnt; int nwords; map mUnk; /// unk symbol int noise_sample_size; noiseSampler m_noiseSampler; map nBests; bool compare_val(const ElemType& first, const ElemType& second); void SaveToFile(std::wstring& outputFile, const Matrix& outputData, const map& idx2wrd, const int& nbest = 1); void ReadLabelInfo(const wstring & vocfile, map & word4idx, map& idx4word); public: ~LMSequenceWriter(){ Destroy(); } public: void GetSections(std::map& /*sections*/){} void SaveMapping(std::wstring saveId, const std::map& /*labelMapping*/){} public: virtual void Init(const ConfigParameters& writerConfig); virtual void Destroy(); virtual bool SaveData(size_t recordStart, const std::map& matrices, size_t numRecords, size_t datasetSize, size_t byteVariableSized); }; } } }