//
//
// Copyright (c) Microsoft Corporation. All rights reserved.
//
//
#pragma once
#include "DataWriter.h"
#include "SequenceParser.h"
#include
// Maximum string length constant (2048). Not referenced anywhere in this
// header; presumably a buffer size used by the writer's implementation file
// -- TODO confirm. Kept as a macro so it remains usable in #if expressions.
#define MAX_STRING 2048
namespace Microsoft {
namespace MSR {
namespace CNTK {
template
void DATAWRITER_API GetWriter(IDataWriter** pwriter)
{
*pwriter = new LMSequenceWriter();
}
// C-linkage export that creates the single-precision (float) writer.
// ElemType does not appear in GetWriter's IDataWriter** parameter, so it cannot
// be deduced from the argument; the element type must be given explicitly.
// NOTE(review): this is a non-inline definition in a header; including the
// header in more than one translation unit would produce duplicate definitions
// of GetWriterF. Consider moving it to a single .cpp file.
extern "C" DATAWRITER_API void GetWriterF(IDataWriter** pwriter)
{
    GetWriter<float>(pwriter);
}
// C-linkage export that creates the double-precision (double) writer.
// ElemType does not appear in GetWriter's IDataWriter** parameter, so it cannot
// be deduced from the argument; the element type must be given explicitly.
// NOTE(review): this is a non-inline definition in a header; including the
// header in more than one translation unit would produce duplicate definitions
// of GetWriterD. Consider moving it to a single .cpp file.
extern "C" DATAWRITER_API void GetWriterD(IDataWriter** pwriter)
{
    GetWriter<double>(pwriter);
}
// LMSequenceWriter: an IDataWriter implementation parameterized on the element
// type ElemType. Only declarations are visible here; Init, Destroy, SaveData,
// compare_val, SaveToFile and ReadLabelInfo are defined out of line.
//
// NOTE(review): template parameter lists and template arguments appear to have
// been stripped from this copy of the file. For example, `template` below has
// no parameter list, `std::vector`/`map`/`Matrix`/`noiseSampler` have no
// arguments, and `map>> class_words` still shows leftover closing brackets.
// None of this is valid C++. `ElemType` (used by compare_val) is the class's
// type parameter. Restore the exact argument lists from version control rather
// than guessing them.
template
class LMSequenceWriter : public IDataWriter
{
private:
// Output bookkeeping; presumably per-output dimensions, output name -> file
// path, and output name -> open file handle -- TODO confirm element types.
std::vector outputDims;
map outputFiles;
map outputFileIds;
std::vector udims;
// Vocabulary / word-class lookup tables (names suggest class count, words per
// class, word<->index, index->class and index->count; types not visible here).
// NOTE(review): class_size, nwords and noise_sample_size have no in-class
// initializer and this class declares no constructor, so their values are
// indeterminate until something (presumably Init) assigns them.
int class_size;
map>> class_words;
map> word4idx;
map> idx4word;
map> idx4class;
map> idx4cnt;
int nwords;
map mUnk; /// unknown-word ("unk") symbol
int noise_sample_size;
noiseSampler m_noiseSampler;
// Presumably the n-best count to write per output -- TODO confirm.
map nBests;
// Comparison helper over two ElemType values (ordering not visible here).
bool compare_val(const ElemType& first, const ElemType& second);
// Writes outputData to outputFile, presumably mapping indices to words through
// idx2wrd; nbest defaults to 1.
void SaveToFile(std::wstring& outputFile, const Matrix& outputData, const map& idx2wrd, const int& nbest = 1);
// Reads vocabulary information from vocfile into the two out-parameters.
// NOTE(review): the parameters word4idx/idx4word shadow the data members of the
// same names; inside the definition the bare names refer to the parameters.
void ReadLabelInfo(const wstring & vocfile,
map & word4idx,
map& idx4word);
public:
// Calls Destroy() during destruction (presumably to release what Init
// acquired). Because this virtual call is made from a destructor, it always
// binds to LMSequenceWriter::Destroy, never to a further-derived override.
~LMSequenceWriter(){
Destroy();
}
public:
// No-op: leaves the caller's sections map untouched.
void GetSections(std::map& /*sections*/){}
// No-op: the label mapping argument is ignored.
void SaveMapping(std::wstring saveId, const std::map& /*labelMapping*/){}
public:
// Lifecycle and output entry points; defined out of line, not in this header.
virtual void Init(const ConfigParameters& writerConfig);
virtual void Destroy();
virtual bool SaveData(size_t recordStart, const std::map& matrices, size_t numRecords, size_t datasetSize, size_t byteVariableSized);
};
}
}
}