CNTK v2 library: Add the ability to broadcast an input without MBLayout along a specified layout, in the ReconcileDynamicAxis node
Parent: 4f4b5fb52f
Commit: 9c8884eaa0
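The change surfaces in the v2 API as a new `Internal::ReconcileDynamicAxes` function (declared in the first hunk below). A minimal usage sketch, assuming the existing v2 `InputVariable` and `Constant::Scalar` factories; the surrounding setup is illustrative, not part of this commit:

    #include "CNTKLibrary.h"
    using namespace CNTK;

    // 'x' carries the default dynamic axes (batch and sequence); 'zero' carries none.
    auto x    = InputVariable({ 10 }, DataType::Float, L"x");
    auto zero = Constant::Scalar(0.0f);

    // After this commit, the scalar is broadcast along x's minibatch layout
    // instead of being rejected for lacking a dynamic axis.
    auto zerosLikeX = Internal::ReconcileDynamicAxes(zero, x);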
@@ -231,6 +231,7 @@ namespace CNTK
         CNTK_API FunctionPtr Scatter(const Variable& operand, const Variable& condition, const std::pair<size_t, int>& newDerivedSequenceAxisScalingAndAdditiveFactor, const std::wstring& name = L"");
         CNTK_API FunctionPtr Slice(const Variable& operand, const Axis& axis, int beginIndex, int endIndex, const std::wstring& name = L"");
         CNTK_API FunctionPtr ReduceElements(const Variable& operand, const std::wstring& reductionOpName, const Axis& axis, const std::wstring& name = L"");
+        CNTK_API FunctionPtr ReconcileDynamicAxes(const Variable& operand, const Variable& axesAsOperand, const std::wstring& name = L"");

         // This is meant for debugging purposes only and is very likely to be deprecated in the future.
         CNTK_API void SaveAsLegacyModel(const FunctionPtr& rootFunction, const std::wstring& modelFile);
@@ -1479,20 +1479,7 @@ namespace CNTK
         FunctionPtr ZeroesWithDynamicAxesLike(const Variable& operand)
         {
-            if (operand.IsSparse())
-            {
-                if (operand.Shape().Rank() > 1)
-                    LogicError("Internal::ZeroesWithDynamicAxesLike: Currently only 1D sparse inputs are supported!");
-
-                // TODO: A matrix multiplication is too expensive for something like this
-                // Replace this with a cheaper operation.
-                return Times(Constant({ 1, operand.Shape()[0] }, operand.GetDataType(), 0.0), operand);
-            }
-            else
-            {
-                auto reduceAllStaticAxesFunc = Internal::ReduceElements(operand, PrimitiveFunction::InternalSumReductionOpName, Axis::AllStaticAxes());
-                return Minus(reduceAllStaticAxesFunc, reduceAllStaticAxesFunc);
-            }
+            return Internal::ReconcileDynamicAxes(Constant::Scalar(0.0f), operand);
         }

         FunctionPtr Where(const Variable& condition, const std::pair<size_t, int>& newDerivedSequenceAxisScalingAndAdditiveFactor, const std::wstring& name)
         {
@@ -1548,5 +1535,11 @@ namespace CNTK
             LogicError("CNTK::ReduceElements: Invalid axis argument provided. To reduce a sequence along its ordered dynamic axis use Sequence::ReduceElements.");
         }

+
+        FunctionPtr ReconcileDynamicAxes(const Variable& operand, const Variable& axesAsOperand, const std::wstring& name)
+        {
+            return BinaryOp(PrimitiveOpType::ReconcileDynamicAxis, operand, axesAsOperand, Dictionary(), name);
+        }
     }
 }
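The deleted branches manufactured zeros from the operand itself purely to inherit its dynamic axes: a zero-matrix `Times` for the sparse case, and sum-then-subtract for the dense case. A hedged sketch of the dense path before and after (`x` stands for any variable with dynamic axes; the other names appear in the diff):

    // Before: derive zeros from the operand's own values to pick up its dynamic axes.
    auto r = Internal::ReduceElements(x, PrimitiveFunction::InternalSumReductionOpName, Axis::AllStaticAxes());
    auto zerosOld = Minus(r, r); // costs a reduction plus a subtraction every minibatch

    // After: stamp a free-standing scalar zero with x's dynamic axes directly.
    auto zerosNew = Internal::ReconcileDynamicAxes(Constant::Scalar(0.0f), x);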
@@ -670,8 +670,6 @@ namespace CNTK
         assert(m_inputs.size() == 2);
         auto operand = m_inputs[0];
         auto layout = m_inputs[1];
-        if (operand.DynamicAxes().empty())
-            InvalidArgument("ReconcileDynamicAxis: input must have at least one dynamic axis");
         if (layout.DynamicAxes().empty())
             InvalidArgument("ReconcileDynamicAxis: layout must have at least one dynamic axis");
         outputShape = operand.Shape();
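With the operand-side check gone, only the layout operand must still carry dynamic axes. Sketched at the API level (hedged; reuses the illustrative `x` from the example above):

    // Accepted now: a data operand with no dynamic axes at all.
    auto c  = Constant::Scalar(1.0f);
    auto ok = Internal::ReconcileDynamicAxes(c, x); // c is broadcast along x's layout

    // Still rejected: the layout operand needs at least one dynamic axis,
    // so swapping the arguments trips the remaining InvalidArgument check.
    // auto bad = Internal::ReconcileDynamicAxes(x, c);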
@@ -246,7 +246,8 @@ private:
 // -----------------------------------------------------------------------
 // ReconcileDynamicAxis (dataInput, layoutInput)
 // This node copies data from 'dataInput' while it propagates the minibatch-layout information from 'layoutInput'.
-// It does perform a runtime check to enforce that the layout of 'dataInput' is compatible (identical content) to that of 'layoutInput'.
+// If 'dataInput' does not have an MBLayout, it broadcasts 'dataInput' along the MBLayout of 'layoutInput'.
+// Otherwise, it performs a runtime check to enforce that the layout of 'dataInput' is compatible (identical content) with that of 'layoutInput'.
 // This node is meant to be used from BrainScript macros that bracket expand/reduce pairs of nodes. It is not meant to really be used directly.
 // TODO: What to do with sequence-boundary flags?
 // -----------------------------------------------------------------------
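The ForwardProp change below routes the copy through the tensor library, which can broadcast a data input that supplies no minibatch layout (conceptually a single column) across the whole output layout. A toy model of that copy, deliberately independent of CNTK's TensorView; every name here is made up for illustration:

    #include <cstddef>
    #include <vector>

    // Toy stand-in for the broadcasting copy: if 'in' holds a single column,
    // replicate it into every one of the 'numCols' output columns; otherwise
    // copy column-for-column (the plain pass-through case).
    void AssignCopyOfBroadcast(std::vector<float>& out, const std::vector<float>& in,
                               std::size_t numRows, std::size_t numCols)
    {
        const bool broadcast = (in.size() == numRows); // one column only
        for (std::size_t j = 0; j < numCols; ++j)
            for (std::size_t i = 0; i < numRows; ++i)
                out[j * numRows + i] = broadcast ? in[i] : in[j * numRows + i];
    }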
@@ -268,22 +269,39 @@ public:
     {
         // enforce compatibility of 'dataInput' with 'layoutInput'
         // TODO: how to deal with boundary flags?
-        if (*m_pMBLayout != *InputRef(0).GetMBLayout()) // this does a deep value-level comparison
+        if (InputRef(0).GetMBLayout() && (*m_pMBLayout != *InputRef(0).GetMBLayout())) // this does a deep value-level comparison
             InvalidArgument("%ls %ls operation discovered that %ls %ls operation produced an MB layout that is incompatible with that of %ls %ls.",
                             NodeName().c_str(), OperationName().c_str(),
                             InputRef(0).NodeName().c_str(), InputRef(0).OperationName().c_str(),
                             InputRef(1).NodeName().c_str(), InputRef(1).OperationName().c_str());

         // copy the data from 'dataInput'
-        ValueFor(fr).AssignValuesOf(InputRef(0).ValueFor(fr.WithLayout(InputRef(0).GetMBLayout()))); // just propagate through
-        // TODO: Once we do in-place, the above must include a copy-to-self check (either here or inside the matrix lib).
+        size_t rank = GetSampleLayout().GetRank();
+        auto result = ValueTensorFor(rank, fr);
+        auto input0 = InputRef(0).ValueTensorFor(rank, InputRef(0).GetMBLayout() ? fr.WithLayout(InputRef(0).GetMBLayout()) : fr.AllowBroadcast());
+        result.AssignCopyOf(input0);
+        // TODO: Once we do in-place, the above must include a copy-to-self check (either here or inside the tensor lib).
     }

     virtual void /*ComputationNode::*/ BackpropTo(const size_t inputIndex, const FrameRange& fr) override
     {
         if (inputIndex == 0)
-            InputRef(0).GradientFor(fr.WithLayout(InputRef(0).GetMBLayout())) += GradientFor(fr);
-        // TODO: Once we do in-place, the above must include a copy-to-self check (pay special attention to adding vs. copying).
+        {
+            size_t rank = GetSampleLayout().GetRank();
+            auto gradient = GradientTensorFor(rank, fr);
+            auto inputGradient = Input(inputIndex)->GradientTensorFor(rank, InputRef(inputIndex).GetMBLayout() ? fr.WithLayout(InputRef(inputIndex).GetMBLayout()) : fr.AllowBroadcast());
+
+            // if reduction then mask the respective input(s) (zero out the gaps)
+            if (Input(inputIndex)->ReducesInTimeWrt(shared_from_this()))
+                MaskMissingGradientColumnsToZero(fr);
+
+            if (Input(inputIndex)->ParentOverwritesGradient())
+                inputGradient.AssignCopyOf(gradient);
+            else
+                inputGradient.AddCopyOf(gradient);
+
+            // TODO: Once we do in-place, the above must include a copy-to-self check (pay special attention to adding vs. copying).
+        }
     }

     virtual bool OutputUsedInComputingInputNodesGradients() const override { return false; }
@@ -292,12 +310,13 @@ public:
     virtual void /*ComputationNodeBase::*/ Validate(bool isFinalValidationPass) override
     {
         Base::Validate(isFinalValidationPass);
-        if (isFinalValidationPass && (!InputRef(0).HasMBLayout() || !InputRef(1).HasMBLayout()))
-            RuntimeError("%ls %ls operation requires two inputs that both have an associated MB layout.", NodeName().c_str(), OperationName().c_str());
+        if (isFinalValidationPass && !InputRef(1).HasMBLayout())
+            RuntimeError("%ls %ls operation requires the 2nd input to have an associated MB layout.", NodeName().c_str(), OperationName().c_str());

         m_pMBLayout = InputRef(1).GetMBLayout(); // output layout is that of 'layoutInput'
         // Note: We could also enforce that both inputs in fact have different layouts. But maybe there are edge cases where it isn't. Then this just becomes a nop. Also OK.

-        SetDims(Input(0));
+        SetDims(InputRef(0).GetSampleLayout(), HasMBLayout());
     }
 };
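A broadcast in the forward direction implies a reduction in the backward direction: the gradient of a layout-free input is the sum of the output gradient over every column it was broadcast into, which is what `AddCopyOf` into a smaller tensor computes. A toy counterpart to the forward sketch above, again with illustrative names only:

    #include <cstddef>
    #include <vector>

    // Toy counterpart of the backward pass: if the input was broadcast from a
    // single column, its gradient accumulates (sums) the output gradient over
    // every column; otherwise gradients add column-for-column.
    void AddCopyOfReduce(std::vector<float>& inGrad, const std::vector<float>& outGrad,
                         std::size_t numRows, std::size_t numCols)
    {
        const bool broadcast = (inGrad.size() == numRows);
        for (std::size_t j = 0; j < numCols; ++j)
            for (std::size_t i = 0; i < numRows; ++i)
                inGrad[broadcast ? i : j * numRows + i] += outGrad[j * numRows + i];
    }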