CNTK/Source/ComputationNetworkLib/SpecialPurposeNodes.cpp

150 строки
6.2 KiB
C++

//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#include "Basics.h"
#include "ComputationNode.h"
#include "SpecialPurposeNodes.h"
#include <string>
#include <vector>
#include <stdexcept>
#include <memory>
namespace Microsoft { namespace MSR { namespace CNTK {
// -----------------------------------------------------------------------
// Trace (node, say='', logFrequency=10, logFirst=10, logGradientToo=false, onlyUpToRow=100000000, onlyUpToT=100000000, format=[])
//
// Debugging aid to trace a node's value using WriteMinibatchWithFormatting().
// -----------------------------------------------------------------------
// Constructs a TraceNode from a configuration record.
// Delegates to the (deviceId, name) constructor with a placeholder name, attaches the
// inputs listed in the record, and then reads the tracing parameters from it.
template <class ElemType>
TraceNode<ElemType>::TraceNode(const ScriptableObjects::IConfigRecordPtr configp) :
    TraceNode(configp->Get(L"deviceId"), L"<placeholder>")
{
    AttachInputsFromConfig(configp, this->GetExpectedNumInputs());

    const auto& config = *configp;

    // message to print, and how often / for which passes to print
    m_message        = (const std::wstring&)config.Get(L"say");
    m_logFirst       = config.Get(L"logFirst");
    m_logFrequency   = config.Get(L"logFrequency");
    m_logGradientToo = config.Get(L"logGradientToo");

    // output formatting is parsed from the same record
    m_formattingOptions = WriteFormattingOptions(config);

    // limits passed on to WriteMinibatchWithFormatting() when logging
    m_onlyUpToRow = config.Get(L"onlyUpToRow");
    m_onlyUpToT   = config.Get(L"onlyUpToT");
}
// Serializes the node: base-class state first, then the tracing parameters.
// The field order here must match the order read back in Load().
template <class ElemType>
/*virtual*/ void TraceNode<ElemType>::Save(File& fstream) const /*override*/
{
    Base::Save(fstream);

    fstream << m_message << m_logFirst << m_logFrequency << m_logGradientToo;

    // BUGBUG: This serializes the pathname of the mapping file to disk. Not nice. But no better solution.
    m_formattingOptions.Save(fstream);

    fstream << m_onlyUpToRow << m_onlyUpToT;
}
// Deserializes the node: base-class state first, then the tracing parameters,
// in the same order in which Save() wrote them.
template <class ElemType>
/*virtual*/ void TraceNode<ElemType>::Load(File& fstream, size_t modelVersion) /*override*/
{
    Base::Load(fstream, modelVersion);

    fstream >> m_message >> m_logFirst >> m_logFrequency >> m_logGradientToo;

    m_formattingOptions.Load(fstream, modelVersion);

    fstream >> m_onlyUpToRow >> m_onlyUpToT;
}
// Runs the base-class forward-prop setup and bumps the counter that Log() uses
// to decide whether the current pass gets traced.
template <class ElemType>
/*virtual*/ void TraceNode<ElemType>::BeginForwardProp() /*override*/
{
    Base::BeginForwardProp();
    m_numMBsRun += 1;
}
// Forward pass: copies the value of input 0 into this node's value for the given
// frame range, then logs the value.
template <class ElemType>
/*virtual*/ void TraceNode<ElemType>::ForwardProp(const FrameRange& fr) /*override*/
{
    const size_t tensorRank = DetermineElementwiseTensorRank();

    // pass the input through unchanged
    auto output = ValueTensorFor(tensorRank, fr);
    output.AssignCopyOf(InputRef(0).ValueTensorFor(tensorRank, fr));

    // trace the value (second argument false = not the gradient)
    Log(fr, false);
}
// Backward pass: adds this node's gradient to the gradient of input 0 for the given
// frame range, and logs the gradient if m_logGradientToo is set.
// inputIndex is expected to be 0 (asserted).
template <class ElemType>
/*virtual*/ void TraceNode<ElemType>::BackpropTo(const size_t inputIndex, const FrameRange& fr) /*override*/
{
    assert(inputIndex == 0);
    (void)inputIndex; // only referenced by the assert above

    const size_t tensorRank = DetermineElementwiseTensorRank();
    auto outputGradient = GradientTensorFor(tensorRank, fr);            // source: our own gradient
    auto inputGradient  = InputRef(0).GradientTensorFor(tensorRank, fr); // destination: the input's gradient
    inputGradient.AddCopyOf(outputGradient);

    // trace the gradient if requested
    if (m_logGradientToo)
    {
        Log(fr, true);
    }
}
// Writes a trace of input 0's value (or, if logGradientInstead is true, its gradient)
// for the given frame range to stderr.
// The formatting prologue is printed whenever m_numMBsRun is 1. The trace itself is
// printed while m_numMBsRun <= m_logFirst, and afterwards on every m_logFrequency-th
// run (never, if m_logFrequency is 0).
template <class ElemType>
/*virtual*/ void TraceNode<ElemType>::Log(const FrameRange& fr, bool logGradientInstead) const
{
    // expands a formatting fragment for this node and the current run counter
    auto expand = [&](const std::string& fragment)
    {
        return m_formattingOptions.Processed(NodeName(), fragment, m_numMBsRun);
    };

    if (m_numMBsRun == 1)
    {
        const auto prologue = expand(m_formattingOptions.prologue);
        fprintf(stderr, "%s", prologue.c_str());
    }

    // nothing more to do unless this run is selected for logging
    if (m_numMBsRun > m_logFirst && (!m_logFrequency || (m_numMBsRun-1) % m_logFrequency != 0))
        return;

    // conversion character: 'f' for plain values; for category labels 's' with a mapping file, 'u' without
    char formatChar = 'f';
    if (m_formattingOptions.isCategoryLabel)
        formatChar = m_formattingOptions.labelMappingFile.empty() ? 'u' : 's';
    auto valueFormatString = "%" + m_formattingOptions.precisionFormat + formatChar; // format string used in fprintf() for formatting the values

    const auto sequenceSeparator = expand(m_formattingOptions.sequenceSeparator);
    const auto sequencePrologue  = expand(m_formattingOptions.sequencePrologue);
    const auto sequenceEpilogue  = expand(m_formattingOptions.sequenceEpilogue);
    const auto elementSeparator  = expand(m_formattingOptions.elementSeparator);
    const auto sampleSeparator   = expand(m_formattingOptions.sampleSeparator);

    const auto timeRange = fr.GetTimeRange();

    // header line; the dashes give visual separation from the actual content
    fprintf(stderr, "------- Trace[");
    if (!fr.IsAllFrames())
    {
        // a single time step prints as "t", several steps as "first..last"
        if (timeRange.second == timeRange.first + 1)
            fprintf(stderr, "%d", (int)timeRange.first);
        else if (timeRange.second > timeRange.first + 1)
            fprintf(stderr, "%d..%d", (int)timeRange.first, (int)timeRange.second-1);
    }
    const char* gradientTag = logGradientInstead ? "(gradient) " : "";
    const auto prototype = InputRef(0).FormatOperationPrototype("");
    fprintf(stderr, "] %ls %s--> %s\n", m_message.c_str(), gradientTag, prototype.c_str());

    // the actual content
    InputRef(0).WriteMinibatchWithFormatting(stderr, fr, m_onlyUpToRow, m_onlyUpToT,
                                             m_formattingOptions.transpose, m_formattingOptions.isCategoryLabel, m_formattingOptions.isSparse, m_labelMapping,
                                             sequenceSeparator, sequencePrologue, sequenceEpilogue, elementSeparator, sampleSeparator,
                                             valueFormatString, logGradientInstead);
}
// Validates the node as a unary map. On the final validation pass, loads the label
// mapping file if one is configured, the data is category-label or sparse, and no
// mapping has been loaded yet. Every call resets the run counter used by Log().
template <class ElemType>
/*virtual*/ void TraceNode<ElemType>::Validate(bool isFinalValidationPass) // override
{
    ValidateUnaryMap(isFinalValidationPass);

    const bool loadLabelMapping =
        isFinalValidationPass &&
        m_labelMapping.empty() &&
        (m_formattingOptions.isCategoryLabel || m_formattingOptions.isSparse) &&
        !m_formattingOptions.labelMappingFile.empty();
    if (loadLabelMapping)
    {
        File::LoadLabelFile(m_formattingOptions.labelMappingFile, m_labelMapping);
    }

    m_numMBsRun = 0;
}
// Explicit instantiations of TraceNode for the element types float, double and half,
// so that the member definitions in this file are compiled for each of them.
template class TraceNode<float>;
template class TraceNode<double>;
template class TraceNode<half>;
}}}