bug fix: InferMBLayoutFromInputsForStandardCase() must test isFinalValidationPass

This commit is contained in:
Frank Seide 2016-03-21 21:34:41 -07:00
Родитель 68770528bb
Коммит 76543e8708
12 изменённых файлов: 27 добавлений и 26 удалений

Просмотреть файл

@ -548,6 +548,7 @@ CNTK_SRC =\
$(SOURCEDIR)/CNTK/ModelEditLanguage.cpp \
$(SOURCEDIR)/CNTK/tests.cpp \
$(SOURCEDIR)/ComputationNetworkLib/ComputationNode.cpp \
$(SOURCEDIR)/ComputationNetworkLib/ComputationNodeScripting.cpp \
$(SOURCEDIR)/ComputationNetworkLib/ReshapingNodes.cpp \
$(SOURCEDIR)/ComputationNetworkLib/ComputationNetwork.cpp \
$(SOURCEDIR)/ComputationNetworkLib/ComputationNetworkEvaluation.cpp \
@ -561,7 +562,7 @@ CNTK_SRC =\
$(SOURCEDIR)/ActionsLib/EvalActions.cpp \
$(SOURCEDIR)/ActionsLib/OtherActions.cpp \
$(SOURCEDIR)/ActionsLib/SpecialPurposeActions.cpp \
$(SOURCEDIR)/ActionsLib/NetworkFactory.cpp \
$(SOURCEDIR)/ActionsLib/NetworkFactory.cpp \
$(SOURCEDIR)/ActionsLib/NetworkDescriptionLanguage.cpp \
$(SOURCEDIR)/ActionsLib/SimpleNetworkBuilder.cpp \
$(SOURCEDIR)/ActionsLib/NDLNetworkBuilder.cpp \

Просмотреть файл

@ -223,7 +223,7 @@ void DoWriteOutput(const ConfigParameters& config)
vector<wstring> outputNodeNamesVector;
auto net = GetModelFromConfig<ConfigParameters, ElemType>(config, outputNodeNamesVector);
let net = GetModelFromConfig<ConfigParameters, ElemType>(config, outputNodeNamesVector);
SimpleOutputWriter<ElemType> writer(net, 1);

Просмотреть файл

@ -81,7 +81,7 @@ void ComputationNode<ElemType>::Backprop(const FrameRange& fr, bool childrenInTh
// - with the exception of NULL layouts (e.g. TimesNode)
// - all layouts may be NULL (e.g. W' = W * Exp(Stabilizer))
// - if there are more than one different layouts involved, this function will fail
void ComputationNodeBase::InferMBLayoutFromInputsForStandardCase()
void ComputationNodeBase::InferMBLayoutFromInputsForStandardCase(bool isFinalValidationPass)
{
MBLayoutPtr pMBLayout; // start with NULL layout
for (auto child : m_inputs)
@ -92,7 +92,7 @@ void ComputationNodeBase::InferMBLayoutFromInputsForStandardCase()
;
else if (!pMBLayout) // first non-NULL layout: just copy it
pMBLayout = child->m_pMBLayout;
else if (pMBLayout != child->m_pMBLayout) // got a layout--compare whether it is the same
else if (pMBLayout != child->m_pMBLayout && isFinalValidationPass) // got a layout--compare whether it is the same
RuntimeError("InferMBLayoutFromInputsForStandardCase: Found inconsistent layout in %ls %ls operation, mismatch detected for child %ls %ls.",
NodeName().c_str(), OperationName().c_str(), child->NodeName().c_str(), child->OperationName().c_str());
}
@ -105,7 +105,7 @@ void ComputationNodeBase::ValidateUnaryMap(bool isFinalValidationPass)
{
assert(m_inputs.size() == 1);
ComputationNodeBase::Validate(isFinalValidationPass);
InferMBLayoutFromInputsForStandardCase();
InferMBLayoutFromInputsForStandardCase(isFinalValidationPass);
SetDims(Input(0));
}
@ -116,7 +116,7 @@ void ComputationNodeBase::ValidateBinaryZip(bool isFinalValidationPass, bool all
{
assert(m_inputs.size() == 2);
ComputationNodeBase::Validate(isFinalValidationPass);
InferMBLayoutFromInputsForStandardCase();
InferMBLayoutFromInputsForStandardCase(isFinalValidationPass);
ValidateInferBinaryInputDims();
@ -193,7 +193,7 @@ void ComputationNodeBase::ValidateInferBinaryInputDims()
assert(m_inputs.size() >= 2);
for (size_t index = 0; index < 2; index++)
{
auto in = Input(index);
auto in = Input( index);
auto other = Input(1 - index);
// borrow any unset dimension on one input from the other input
in->ValidateInferInputDimsFrom(other->GetSampleLayout());

Просмотреть файл

@ -635,7 +635,7 @@ protected:
void ValidateInferBinaryInputDims();
void ValidateBinaryZip(bool isFinalValidationPass, bool allowBroadcast);
void ValidateBinaryReduce(bool isFinalValidationPass);
void InferMBLayoutFromInputsForStandardCase();
void InferMBLayoutFromInputsForStandardCase(bool isFinalValidationPass);
virtual void ValidateInferInputDimsFrom(const TensorShape&) = 0; // (implemented by ComputationNode<ElemType>)
public:

Просмотреть файл

@ -201,7 +201,7 @@ public:
void /*ComputationNodeBase::*/ Validate(bool isFinalValidationPass) override
{
Base::Validate(isFinalValidationPass);
InferMBLayoutFromInputsForStandardCase();
InferMBLayoutFromInputsForStandardCase(isFinalValidationPass);
// get input and output tensor shape and interpret as image dimensions
auto inDims = ImageDimensions(GetInputSampleLayout(1), m_imageLayoutKind);
@ -423,7 +423,7 @@ public:
void Validate(bool isFinalValidationPass) override
{
Base::Validate(isFinalValidationPass);
InferMBLayoutFromInputsForStandardCase();
InferMBLayoutFromInputsForStandardCase(isFinalValidationPass);
// get input tensor shape and interpret as image dimensions
auto inDims = ImageDimensions(GetInputSampleLayout(0), m_imageLayoutKind);

Просмотреть файл

@ -318,7 +318,7 @@ public:
virtual void /*ComputationNodeBase::*/ Validate(bool isFinalValidationPass) override
{
Base::Validate(isFinalValidationPass);
InferMBLayoutFromInputsForStandardCase();
InferMBLayoutFromInputsForStandardCase(isFinalValidationPass);
if (isFinalValidationPass)
if (!(Input(1)->GetSampleMatrixNumRows() == Input(2)->GetSampleMatrixNumRows() && // position dependent and pair scores have same number of labels

Просмотреть файл

@ -689,7 +689,7 @@ public:
virtual void /*ComputationNodeBase::*/ Validate(bool isFinalValidationPass) override
{
Base::Validate(isFinalValidationPass);
InferMBLayoutFromInputsForStandardCase();
InferMBLayoutFromInputsForStandardCase(isFinalValidationPass);
if (isFinalValidationPass && !HasMBLayout())
InvalidArgument("%ls %ls operation can only operate on minibatches.", NodeName().c_str(), OperationName().c_str());

Просмотреть файл

@ -245,7 +245,7 @@ public:
Base::Validate(isFinalValidationPass);
if (isFinalValidationPass && Input(0)->HasMBLayout())
InvalidArgument("%ls %ls operation requires the first factor to not be minibatch data (must not have an MBLayout).", NodeName().c_str(), OperationName().c_str());
InferMBLayoutFromInputsForStandardCase();
InferMBLayoutFromInputsForStandardCase(isFinalValidationPass);
bool transpose = m_transpose; // (assigning to a non-const variable avoids a compiler warning C4127: conditional expression is constant)
@ -486,7 +486,7 @@ public:
virtual void /*ComputationNodeBase::*/ Validate(bool isFinalValidationPass) override
{
Base::Validate(isFinalValidationPass);
InferMBLayoutFromInputsForStandardCase();
InferMBLayoutFromInputsForStandardCase(isFinalValidationPass);
size_t rows0 = Input(0)->GetAsMatrixNumRows();
size_t rows1 = Input(1)->HasMBLayout() ? Input(1)->GetSampleMatrixNumRows() : Input(1)->GetAsMatrixNumRows();
@ -651,7 +651,7 @@ public:
virtual void /*ComputationNodeBase::*/ Validate(bool isFinalValidationPass) override
{
Base::Validate(isFinalValidationPass);
InferMBLayoutFromInputsForStandardCase();
InferMBLayoutFromInputsForStandardCase(isFinalValidationPass);
SetDims(TensorShape(1), Input(0)->HasMBLayout()); // each column is reduced to a scalar
}
@ -745,7 +745,7 @@ public:
{
assert(m_inputs.size() == 1);
ComputationNodeBase::Validate(isFinalValidationPass);
InferMBLayoutFromInputsForStandardCase();
InferMBLayoutFromInputsForStandardCase(isFinalValidationPass);
// input shape
auto shape = Input(0)->GetSampleLayout();
@ -842,7 +842,7 @@ public:
virtual void /*ComputationNodeBase::*/ Validate(bool isFinalValidationPass) override
{
Base::Validate(isFinalValidationPass);
InferMBLayoutFromInputsForStandardCase();
InferMBLayoutFromInputsForStandardCase(isFinalValidationPass);
ValidateInferBinaryInputDims();
@ -963,7 +963,7 @@ public:
virtual void /*ComputationNodeBase::*/ Validate(bool isFinalValidationPass) override
{
Base::Validate(isFinalValidationPass);
InferMBLayoutFromInputsForStandardCase();
InferMBLayoutFromInputsForStandardCase(isFinalValidationPass);
size_t rows0 = Input(0)->GetSampleMatrixNumRows();
size_t rows1 = Input(1)->GetSampleMatrixNumRows();
@ -1152,7 +1152,7 @@ public:
virtual void /*ComputationNodeBase::*/ Validate(bool isFinalValidationPass) override
{
Base::Validate(isFinalValidationPass);
InferMBLayoutFromInputsForStandardCase();
InferMBLayoutFromInputsForStandardCase(isFinalValidationPass);
ValidateInferBinaryInputDims();

Просмотреть файл

@ -426,7 +426,7 @@ public:
virtual void /*ComputationNodeBase::*/ Validate(bool isFinalValidationPass) override
{
Base::Validate(isFinalValidationPass);
InferMBLayoutFromInputsForStandardCase();
InferMBLayoutFromInputsForStandardCase(isFinalValidationPass);
Input(1)->ValidateInferInputDimsFrom(Input(0)->GetSampleLayout());
Input(2)->ValidateInferInputDimsFrom(Input(0)->GetSampleLayout());
@ -527,7 +527,7 @@ public:
virtual void /*ComputationNodeBase::*/ Validate(bool isFinalValidationPass) override
{
Base::Validate(isFinalValidationPass);
InferMBLayoutFromInputsForStandardCase();
InferMBLayoutFromInputsForStandardCase(isFinalValidationPass);
Input(1)->ValidateInferInputDimsFrom(Input(0)->GetSampleLayout());
Input(2)->ValidateInferInputDimsFrom(Input(0)->GetSampleLayout());

Просмотреть файл

@ -314,7 +314,7 @@ public:
virtual void /*ComputationNodeBase::*/ Validate(bool isFinalValidationPass) override
{
Base::Validate(isFinalValidationPass);
InferMBLayoutFromInputsForStandardCase();
InferMBLayoutFromInputsForStandardCase(isFinalValidationPass);
auto sampleLayout = Input(0)->GetSampleLayout();
if (isFinalValidationPass && sampleLayout[0] < m_startIndex + m_sliceHeight)
@ -421,7 +421,7 @@ public:
virtual void /*ComputationNodeBase::*/ Validate(bool isFinalValidationPass) override
{
Base::Validate(isFinalValidationPass);
InferMBLayoutFromInputsForStandardCase();
InferMBLayoutFromInputsForStandardCase(isFinalValidationPass);
// we must fuse all tensor shapes
// All dimensions but the last must be the same. (In a future version, we should be able to stack along any given dimension.)
@ -528,7 +528,7 @@ public:
virtual void /*ComputationNodeBase::*/ Validate(bool isFinalValidationPass) override
{
Base::Validate(isFinalValidationPass);
InferMBLayoutFromInputsForStandardCase();
InferMBLayoutFromInputsForStandardCase(isFinalValidationPass);
// the trailing dimension gets multiplied
// TODO: Or should we add an additional dimension?

Просмотреть файл

@ -315,7 +315,7 @@ public:
virtual void /*ComputationNodeBase::*/ Validate(bool isFinalValidationPass) override
{
Base::Validate(isFinalValidationPass);
InferMBLayoutFromInputsForStandardCase();
InferMBLayoutFromInputsForStandardCase(isFinalValidationPass);
size_t rows[4];
for (int i = 0; i < 4; i++)

Просмотреть файл

@ -1747,7 +1747,7 @@ public:
void Validate(bool isFinalValidationPass) override
{
Base::Validate(isFinalValidationPass);
InferMBLayoutFromInputsForStandardCase();
InferMBLayoutFromInputsForStandardCase(isFinalValidationPass);
SetDims(Input(0));