CNTK v2 library: Add the ability to broadcast an input without MBLayout along a specified layout, in the ReconcileDynamicAxis node

Amit Agarwal 2017-02-05 12:38:27 -08:00
Parent 4f4b5fb52f
Commit 9c8884eaa0
4 changed files with 36 additions and 25 deletions

View file

@@ -231,6 +231,7 @@ namespace CNTK
         CNTK_API FunctionPtr Scatter(const Variable& operand, const Variable& condition, const std::pair<size_t, int>& newDerivedSequenceAxisScalingAndAdditiveFactor, const std::wstring& name = L"");
         CNTK_API FunctionPtr Slice(const Variable& operand, const Axis& axis, int beginIndex, int endIndex, const std::wstring& name = L"");
         CNTK_API FunctionPtr ReduceElements(const Variable& operand, const std::wstring& reductionOpName, const Axis& axis, const std::wstring& name = L"");
+        CNTK_API FunctionPtr ReconcileDynamicAxes(const Variable& operand, const Variable& axesAsOperand, const std::wstring& name = L"");
 
         // This is meant for debugging purposes only and is very likely to be deprecated in the future.
         CNTK_API void SaveAsLegacyModel(const FunctionPtr& rootFunction, const std::wstring& modelFile);
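
This new export is the only API-surface change in the commit. A minimal sketch of the intended call pattern, assuming the CNTK v2 C++ API of this period (InputVariable and Constant::Scalar are pre-existing API; the variable names here are illustrative):

    #include "CNTKLibrary.h"
    using namespace CNTK;

    int main()
    {
        // A sequence input: carries the default batch and sequence dynamic axes.
        auto features = InputVariable({ 3 }, DataType::Float, L"features");

        // A scalar constant: no dynamic axes, hence no MBLayout.
        auto zero = Constant::Scalar(0.0f);

        // Broadcast 'zero' along the dynamic axes of 'features'. The output keeps
        // the operand's sample shape but adopts the layout of 'features'.
        auto broadcast = Internal::ReconcileDynamicAxes(zero, features, L"broadcastZero");
        return 0;
    }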

View file

@@ -1479,20 +1479,7 @@ namespace CNTK
 
         FunctionPtr ZeroesWithDynamicAxesLike(const Variable& operand)
         {
-            if (operand.IsSparse())
-            {
-                if (operand.Shape().Rank() > 1)
-                    LogicError("Internal::ZeroesWithDynamicAxesLike: Currently only 1D sparse inputs are supported!");
-
-                // TODO: A matrix multiplication is too expensive for something like this
-                // Replace this with a cheaper operation.
-                return Times(Constant({ 1, operand.Shape()[0] }, operand.GetDataType(), 0.0), operand);
-            }
-            else
-            {
-                auto reduceAllStaticAxesFunc = Internal::ReduceElements(operand, PrimitiveFunction::InternalSumReductionOpName, Axis::AllStaticAxes());
-                return Minus(reduceAllStaticAxesFunc, reduceAllStaticAxesFunc);
-            }
+            return Internal::ReconcileDynamicAxes(Constant::Scalar(0.0f), operand);
         }
 
         FunctionPtr Where(const Variable& condition, const std::pair<size_t, int>& newDerivedSequenceAxisScalingAndAdditiveFactor, const std::wstring& name)
@@ -1548,5 +1535,11 @@ namespace CNTK
                 LogicError("CNTK::ReduceElements: Invalid axis argument provided. To reduce a sequence along its ordered dynamic axis use Sequence::ReduceElements.");
         }
+
+        FunctionPtr ReconcileDynamicAxes(const Variable& operand, const Variable& axesAsOperand, const std::wstring& name)
+        {
+            return BinaryOp(PrimitiveOpType::ReconcileDynamicAxis, operand, axesAsOperand, Dictionary(), name);
+        }
+
     }
 }
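
Taken together, the two hunks above replace an expensive zeros-construction with the new primitive. Side by side, as library-internal graph construction (this sketch reuses internal names from the diff, e.g. PrimitiveFunction::InternalSumReductionOpName, so it only compiles inside the CNTK sources):

    // Old recipe (removed): a full sum-reduction over all static axes followed by
    // a self-subtraction, just to obtain zeros that carry the operand's dynamic
    // axes; sparse inputs additionally needed a Times against a zero constant.
    FunctionPtr ZeroesLikeOld(const Variable& operand)
    {
        auto sum = Internal::ReduceElements(operand, PrimitiveFunction::InternalSumReductionOpName, Axis::AllStaticAxes());
        return Minus(sum, sum);
    }

    // New recipe (added): broadcast one scalar zero along the operand's dynamic
    // axes via the ReconcileDynamicAxis primitive.
    FunctionPtr ZeroesLikeNew(const Variable& operand)
    {
        return Internal::ReconcileDynamicAxes(Constant::Scalar(0.0f), operand);
    }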

View file

@@ -670,8 +670,6 @@ namespace CNTK
         assert(m_inputs.size() == 2);
         auto operand = m_inputs[0];
         auto layout = m_inputs[1];
-        if (operand.DynamicAxes().empty())
-            InvalidArgument("ReconcileDynamicAxis: input must have at least one dynamic axis");
         if (layout.DynamicAxes().empty())
             InvalidArgument("ReconcileDynamicAxis: layout must have at least one dynamic axis");
         outputShape = operand.Shape();
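
The dropped check is exactly what used to reject the broadcast case; the surviving check is deliberately asymmetric, since only the second input supplies the layout. A sketch of what now validates and what still raises (assuming the v2 C++ API; names illustrative):

    #include "CNTKLibrary.h"
    using namespace CNTK;

    int main()
    {
        auto seq  = InputVariable({ 2 }, DataType::Float, L"seq"); // has dynamic axes
        auto bias = Constant::Scalar(1.0f);                        // has none

        // Valid since this commit: the layout-free operand is broadcast along 'seq'.
        auto ok = Internal::ReconcileDynamicAxes(bias, seq);

        // Still invalid, because the second input supplies the layout:
        // Internal::ReconcileDynamicAxes(seq, bias) raises
        // "ReconcileDynamicAxis: layout must have at least one dynamic axis".
        return 0;
    }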

View file

@@ -246,7 +246,8 @@ private:
 // -----------------------------------------------------------------------
 // ReconcileDynamicAxis (dataInput, layoutInput)
 // This node copies data from 'dataInput' while it propagates the minibatch-layout information from 'layoutInput'.
-// It does perform a runtime check to enforce that the layout of 'dataInput' is compatible (identical content) to that of 'layoutInput'.
+// If 'dataInput' does not have an MBLayout, it broadcasts 'dataInput' along the MBLayout of 'layoutInput'.
+// It performs a runtime check to enforce that the layout of 'dataInput' is compatible with that of 'layoutInput'.
 // This node is meant to be used from BrainScript macros that bracket expand/reduce pairs of nodes. It is not meant to really be used directly.
 // TODO: What to do with sequence-boundary flags?
 // -----------------------------------------------------------------------
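
To make the broadcast case concrete: a layout-free 'dataInput' contributes one value per sample, and the node repeats that value across every frame of the MBLayout of 'layoutInput'. A standalone sketch with plain arrays (no CNTK types; the frame count and values are made up):

    #include <cstdio>

    int main()
    {
        const int T = 5;          // frames in the layout input's MBLayout
        float dataInput = 3.0f;   // layout-free input: one value, no time dimension
        float output[T];

        for (int t = 0; t < T; ++t)   // forward: broadcast across all frames
            output[t] = dataInput;

        for (int t = 0; t < T; ++t)
            std::printf("frame %d -> %.1f\n", t, output[t]);
        return 0;
    }
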
@@ -268,22 +269,39 @@ public:
     {
         // enforce compatibility of 'dataInput' with 'layoutInput'
         // TODO: how to deal with boundary flags?
-        if (*m_pMBLayout != *InputRef(0).GetMBLayout()) // this does a deep value-level comparison
+        if (InputRef(0).GetMBLayout() && (*m_pMBLayout != *InputRef(0).GetMBLayout())) // this does a deep value-level comparison
             InvalidArgument("%ls %ls operation discovered that %ls %ls operation produced an MB layout that is incompatible with that of %ls %ls.",
                             NodeName().c_str(), OperationName().c_str(),
                             InputRef(0).NodeName().c_str(), InputRef(0).OperationName().c_str(),
                             InputRef(1).NodeName().c_str(), InputRef(1).OperationName().c_str());
 
         // copy the data from 'dataInput'
-        ValueFor(fr).AssignValuesOf(InputRef(0).ValueFor(fr.WithLayout(InputRef(0).GetMBLayout()))); // just propagate through
-        // TODO: Once we do in-place, the above must include a copy-to-self check (either here or inside the matrix lib).
+        size_t rank = GetSampleLayout().GetRank();
+        auto result = ValueTensorFor(rank, fr);
+        auto input0 = InputRef(0).ValueTensorFor(rank, InputRef(0).GetMBLayout() ? fr.WithLayout(InputRef(0).GetMBLayout()) : fr.AllowBroadcast());
+        result.AssignCopyOf(input0);
+        // TODO: Once we do in-place, the above must include a copy-to-self check (either here or inside the tensor lib).
     }
 
     virtual void /*ComputationNode::*/ BackpropTo(const size_t inputIndex, const FrameRange& fr) override
     {
         if (inputIndex == 0)
-            InputRef(0).GradientFor(fr.WithLayout(InputRef(0).GetMBLayout())) += GradientFor(fr);
-        // TODO: Once we do in-place, the above must include a copy-to-self check (pay special attention to adding vs. copying).
+        {
+            size_t rank = GetSampleLayout().GetRank();
+            auto gradient = GradientTensorFor(rank, fr);
+            auto inputGradient = Input(inputIndex)->GradientTensorFor(rank, InputRef(inputIndex).GetMBLayout() ? fr.WithLayout(InputRef(inputIndex).GetMBLayout()) : fr.AllowBroadcast());
+
+            // if reduction then mask the respective input(s) (zero out the gaps)
+            if (Input(inputIndex)->ReducesInTimeWrt(shared_from_this()))
+                MaskMissingGradientColumnsToZero(fr);
+
+            if (Input(inputIndex)->ParentOverwritesGradient())
+                inputGradient.AssignCopyOf(gradient);
+            else
+                inputGradient.AddCopyOf(gradient);
+
+            // TODO: Once we do in-place, the above must include a copy-to-self check (pay special attention to adding vs. copying).
+        }
     }
 
     virtual bool OutputUsedInComputingInputNodesGradients() const override { return false; }
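
The forward broadcast dictates the backward pass: the gradient reaching a layout-free input must be summed over every frame the value was broadcast to, which is what AddCopyOf amounts to when the input-gradient view carries no minibatch axes; masking the gap columns to zero first keeps padding out of that sum. A standalone sketch of the reduction (plain C++, made-up values):

    #include <cstdio>

    int main()
    {
        const int T = 4;             // frames the input was broadcast across
        float upstream[T] = { 0.1f, 0.2f, 0.3f, 0.4f };
        float inputGradient = 0.0f;  // gradient for the layout-free input

        for (int t = 0; t < T; ++t)  // backward: reduce over the broadcast frames
            inputGradient += upstream[t];

        std::printf("accumulated input gradient = %.1f\n", inputGradient); // 1.0
        return 0;
    }
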
@@ -292,12 +310,13 @@ public:
     virtual void /*ComputationNodeBase::*/ Validate(bool isFinalValidationPass) override
     {
         Base::Validate(isFinalValidationPass);
 
-        if (isFinalValidationPass && (!InputRef(0).HasMBLayout() || !InputRef(1).HasMBLayout()))
-            RuntimeError("%ls %ls operation requires two inputs that both have an associated MB layout.", NodeName().c_str(), OperationName().c_str());
+        if (isFinalValidationPass && !InputRef(1).HasMBLayout())
+            RuntimeError("%ls %ls operation requires the 2nd input to have an associated MB layout.", NodeName().c_str(), OperationName().c_str());
+
         m_pMBLayout = InputRef(1).GetMBLayout(); // output layout is that of 'layoutInput'
 
         // Note: We could also enforce that both inputs in fact have different layouts. But maybe there are edge cases where it isn't. Then this just becomes a nop. Also OK.
-        SetDims(Input(0));
+        SetDims(InputRef(0).GetSampleLayout(), HasMBLayout());
     }
 };