Fix KaldiReader for the new interface (SetValue).

This commit is contained in:
Yu 2015-10-28 18:36:52 -04:00
Родитель 8228164c2e
Коммит 661942127e
2 изменённых файлов: 18 добавлений и 8 удалений

Просмотреть файл

@ -1106,6 +1106,14 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// We initialize the sentence boundary information before we process
// the utterances.
m_pMBLayout->Init(m_numberOfuttsPerMinibatch, m_currentMBSize, !m_framemode);
for (size_t i = 0; i < m_numberOfuttsPerMinibatch; i++)
{
for (size_t j = 0; j < m_currentMBSize; j++)
{
m_pMBLayout->SetWithoutOr(i, j, MinibatchPackingFlags::None);
}
}
// Iterates over utterances. m_numberOfuttsPerMinibatch = 1 is a
// special case.
for (size_t i = 0; i < m_numberOfuttsPerMinibatch; i++)
@ -1412,6 +1420,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
assert(id < m_minibatchBuffer[index].features.size());
data.SetValue(dim,
m_minibatchBuffer[index].features[id].size() / dim,
data.GetDeviceId(),
m_minibatchBuffer[index].features[id].data(),
matrixFlagNormal);
}
@ -1422,6 +1431,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
assert(id < m_minibatchBuffer[index].labels.size());
data.SetValue(dim,
m_minibatchBuffer[index].labels[id].size() / dim,
data.GetDeviceId(),
m_minibatchBuffer[index].labels[id].data(),
matrixFlagNormal);
}
@ -1488,14 +1498,14 @@ namespace Microsoft { namespace MSR { namespace CNTK {
size_t id = m_featureNameToIdMap.at(iter->first);
size_t dim = m_featureNameToDimMap.at(iter->first);
assert(id < featureBuffer.size());
data.SetValue(dim, size, featureBuffer[id] , matrixFlagNormal);
data.SetValue(dim, size, data.GetDeviceId(), featureBuffer[id] , matrixFlagNormal);
}
else if (m_nameToTypeMap.at(iter->first) == InputOutputTypes::category)
{
size_t id = m_labelNameToIdMap.at(iter->first);
size_t dim = m_labelNameToDimMap.at(iter->first);
assert(id < labelBuffer.size());
data.SetValue(dim, size, labelBuffer[id], matrixFlagNormal);
data.SetValue(dim, size, data.GetDeviceId(), labelBuffer[id], matrixFlagNormal);
}
else if (m_doMinibatchBuffering)
{
@ -1674,7 +1684,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
}
}
}
data.SetValue(feat.rows(), feat.cols(), m_featuresBufferMultiIO[id],matrixFlagNormal);
data.SetValue(feat.rows(), feat.cols(), data.GetDeviceId(), m_featuresBufferMultiIO[id],matrixFlagNormal);
}
else
{ // Resizes other inputs so they won't affect actual minibatch size.

Просмотреть файл

@ -851,7 +851,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
}
}
}
data.SetValue(feat.rows(), feat.cols(), m_featuresBufferMultiIO[id],matrixFlagNormal);
data.SetValue(feat.rows(), feat.cols(), data.GetDeviceId(), m_featuresBufferMultiIO[id],matrixFlagNormal);
}
}
else if (m_nameToTypeMap[iter->first] == InputOutputTypes::category)
@ -919,7 +919,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
}
data.SetValue(dim,uids.size(),m_labelsBufferMultiIO[id],matrixFlagNormal);
data.SetValue(dim,uids.size(),data.GetDeviceId(), m_labelsBufferMultiIO[id],matrixFlagNormal);
}
}
@ -1190,13 +1190,13 @@ namespace Microsoft { namespace MSR { namespace CNTK {
{
id = m_featureNameToIdMap[iter->first];
dim = m_featureNameToDimMap[iter->first];
data.SetValue(dim, m_mbSize*m_numberOfuttsPerMinibatch, m_featuresBufferMultiIO[id],matrixFlagNormal);
data.SetValue(dim, m_mbSize*m_numberOfuttsPerMinibatch, data.GetDeviceId(), m_featuresBufferMultiIO[id],matrixFlagNormal);
}
else if (m_nameToTypeMap[iter->first] == InputOutputTypes::category)
{
id = m_labelNameToIdMap[iter->first];
dim = m_labelNameToDimMap[iter->first];
data.SetValue(dim, m_mbSize*m_numberOfuttsPerMinibatch, m_labelsBufferMultiIO[id],matrixFlagNormal);
data.SetValue(dim, m_mbSize*m_numberOfuttsPerMinibatch, data.GetDeviceId(), m_labelsBufferMultiIO[id],matrixFlagNormal);
}
}
skip=false;
@ -1317,7 +1317,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
}
}
}
data.SetValue(feat.rows(), feat.cols(), m_featuresBufferMultiIO[id],matrixFlagNormal);
data.SetValue(feat.rows(), feat.cols(), data.GetDeviceId(), m_featuresBufferMultiIO[id],matrixFlagNormal);
}
}
return true;