199 строки
7.2 KiB
C++
199 строки
7.2 KiB
C++
//
|
|
// <copyright file="SequenceWriter.cpp" company="Microsoft">
|
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// </copyright>
|
|
//
|
|
|
|
//
|
|
|
|
#include "stdafx.h"
|
|
#include <objbase.h>
|
|
#include "Basics.h"
|
|
#include <fstream>
|
|
#include <algorithm>
|
|
|
|
#define DATAWRITER_EXPORTS // creating the exports here
|
|
#include "DataWriter.h"
|
|
#include "SequenceReader.h"
|
|
#include "SequenceWriter.h"
|
|
#include "commandArgUtil.h"
|
|
#ifdef LEAKDETECT
|
|
#include <vld.h> // for memory leak detection
|
|
#endif
|
|
|
|
namespace Microsoft {
|
|
namespace MSR {
|
|
namespace CNTK {
|
|
|
|
// Create a Data Writer
|
|
//DATAWRITER_API IDataWriter* DataWriterFactory(void)
|
|
|
|
|
|
// comparison, not case sensitive.
|
|
template<class ElemType>
|
|
bool LMSequenceWriter<ElemType>::compare_val(const ElemType& first, const ElemType& second)
|
|
{
|
|
return (first < second);
|
|
}
|
|
|
|
template<class ElemType>
|
|
void LMSequenceWriter<ElemType>::Init(const ConfigParameters& writerConfig)
|
|
{
|
|
udims.clear();
|
|
|
|
ConfigArray outputNames = writerConfig("outputNodeNames", "");
|
|
if (outputNames.size()<1)
|
|
RuntimeError("writer needs at least one outputNodeName specified in config");
|
|
|
|
foreach_index(i, outputNames) // inputNames should map to node names
|
|
{
|
|
ConfigParameters thisOutput = writerConfig(outputNames[i]);
|
|
outputFiles[outputNames[i]] = thisOutput("file");
|
|
int iN = thisOutput("nbest", "1");
|
|
nBests[outputNames[i]] = iN;
|
|
wstring fname = thisOutput("token");
|
|
/// read unk sybol
|
|
mUnk[outputNames[i]] = writerConfig("unk", "<unk>");
|
|
|
|
SequenceReader<ElemType>::ReadClassInfo(fname, class_size,
|
|
word4idx[outputNames[i]],
|
|
idx4word[outputNames[i]],
|
|
idx4class[outputNames[i]],
|
|
idx4cnt[outputNames[i]],
|
|
0,
|
|
mUnk[outputNames[i]],
|
|
m_noiseSampler,
|
|
false);
|
|
size_t dim = idx4word[outputNames[i]].size();
|
|
udims.push_back(dim);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
template<class ElemType>
|
|
void LMSequenceWriter<ElemType>::ReadLabelInfo(const wstring & vocfile,
|
|
map<string, int> & word4idx,
|
|
map<int, string>& idx4word)
|
|
{
|
|
char strFileName[MAX_STRING];
|
|
char stmp[MAX_STRING];
|
|
string strtmp;
|
|
size_t sz;
|
|
int b;
|
|
|
|
wcstombs_s(&sz, strFileName, 2048, vocfile.c_str(), vocfile.length());
|
|
|
|
FILE * vin;
|
|
vin = fopen(strFileName, "rt");
|
|
|
|
if (vin == nullptr)
|
|
{
|
|
RuntimeError("cannot open word class file");
|
|
}
|
|
b = 0;
|
|
while (!feof(vin)){
|
|
fscanf_s(vin, "%s\n", stmp, _countof(stmp));
|
|
word4idx[stmp] = b;
|
|
idx4word[b++] = stmp;
|
|
}
|
|
fclose(vin);
|
|
|
|
}
|
|
|
|
template<class ElemType>
|
|
void LMSequenceWriter<ElemType>::Destroy()
|
|
{
|
|
for (auto ptr = outputFileIds.begin(); ptr != outputFileIds.end(); ptr++)
|
|
{
|
|
fclose(ptr->second);
|
|
}
|
|
}
|
|
|
|
template<class ElemType>
|
|
bool LMSequenceWriter<ElemType>::SaveData(size_t /*recordStart*/, const std::map<std::wstring, void*, nocase_compare>& matrices, size_t /*numRecords*/, size_t /*datasetSize*/, size_t /*byteVariableSized*/)
|
|
{
|
|
|
|
for (auto iter = matrices.begin(); iter != matrices.end(); iter++)
|
|
{
|
|
string outputName = ws2s(iter->first);
|
|
Matrix<ElemType>& outputData = *(static_cast<Matrix<ElemType>*>(iter->second));
|
|
wstring outFile = outputFiles[s2ws(outputName)];
|
|
|
|
SaveToFile(outFile, outputData, idx4word[iter->first], nBests[outputName]);
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
template<class ElemType>
|
|
void LMSequenceWriter<ElemType>::SaveToFile(std::wstring& outputFile, const Matrix<ElemType>& outputData, const map<int, string>& idx2wrd, const int& nbest)
|
|
{
|
|
size_t nT = outputData.GetNumCols();
|
|
size_t nD = min(idx2wrd.size(), outputData.GetNumRows());
|
|
FILE *fp = nullptr;
|
|
vector<pair<size_t, ElemType>> lv;
|
|
|
|
auto NbestComparator = [](const pair<size_t, ElemType>& lv, const pair<size_t, ElemType>& rv){return lv.second > rv.second; };
|
|
|
|
if (outputFileIds.find(outputFile) == outputFileIds.end())
|
|
{
|
|
FILE* ofs;
|
|
msra::files::make_intermediate_dirs(outputFile);
|
|
string str(outputFile.begin(), outputFile.end());
|
|
ofs = fopen(str.c_str(), "wt");
|
|
if (ofs == nullptr)
|
|
RuntimeError("Cannot open %s for writing", str.c_str());
|
|
outputFileIds[outputFile] = ofs;
|
|
fp = ofs;
|
|
}
|
|
else
|
|
fp = outputFileIds[outputFile];
|
|
|
|
for (int j = 0; j< nT; j++)
|
|
{
|
|
int imax = 0;
|
|
ElemType fmax = outputData(imax, j);
|
|
lv.clear();
|
|
if (nbest > 1) lv.push_back(pair<size_t, ElemType>(0, fmax));
|
|
for (int i = 1; i<nD; i++)
|
|
{
|
|
if (nbest > 1) lv.push_back(pair<size_t, ElemType>(i, outputData(i, j)));
|
|
if (outputData(i, j) > fmax)
|
|
{
|
|
fmax = outputData(i, j);
|
|
imax = i;
|
|
}
|
|
}
|
|
if (nbest > 1) sort(lv.begin(), lv.end(), NbestComparator);
|
|
for (int i = 0; i < nbest; i++)
|
|
{
|
|
if (nbest > 1)
|
|
{
|
|
if (lv[i].second != 0)
|
|
{
|
|
int idx = (int)lv[i].first;
|
|
string sRes = idx2wrd.find(idx)->second;
|
|
fprintf(fp, "%s ", sRes.c_str());
|
|
}
|
|
}
|
|
else
|
|
{
|
|
string sRes = idx2wrd.find(imax)->second;
|
|
fprintf(fp, "%s ", sRes.c_str());
|
|
fprintf(stderr, "%s ", sRes.c_str());
|
|
}
|
|
}
|
|
}
|
|
fprintf(fp, "\n");
|
|
fprintf(stderr, "\n");
|
|
}
|
|
|
|
|
|
// Explicit instantiation of LMSequenceWriter for float.
template class LMSequenceWriter<float>;
|
|
// Explicit instantiation of LMSequenceWriter for double.
template class LMSequenceWriter<double>;
|
|
|
|
}
|
|
}
|
|
}
|