Integrate amitaga/v2Beta3 into master
This commit is contained in:
Commit 76aca3029c
@@ -2399,6 +2399,11 @@ namespace CNTK
         ///
         CNTK_API static FunctionPtr LoadModel(DataType dataType, const std::wstring& modelFile, const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());
 
+        ///
+        /// Prints the entire graph underlying this function to stderr
+        ///
+        CNTK_API void PrintGraph() const;
+
     private:
 
         template <typename VariableType, typename FilterFunction>
@@ -2899,6 +2904,13 @@ namespace CNTK
         CNTK_API FunctionPtr IsFirst(const Variable& operand, const std::wstring& name = L"");
         CNTK_API FunctionPtr IsLast(const Variable& operand, const std::wstring& name = L"");
+
+        CNTK_API FunctionPtr Slice(const Variable& operand, int beginIndex, int endIndex, const std::wstring& name = L"");
+
+        ///
+        /// Create an instance of the CNTK built-in sum reduction operation on a specified tensor input operand, along the operand's lone dynamic sequence axis
+        ///
+        CNTK_API FunctionPtr ReduceSum(const Variable& operand, const std::wstring& name = L"");
 
         CNTK_API FunctionPtr First(const Variable& operand, const std::wstring& name = L"");
         CNTK_API FunctionPtr Last(const Variable& operand, const std::wstring& name = L"");
 
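These declarations give sequences their own slice and reduction entry points. A minimal sketch of how they compose, assuming the CNTK v2 C++ API as declared above (variable names are illustrative, not from the diff):

#include "CNTKLibrary.h"
using namespace CNTK;

void SequenceOpsSketch()
{
    // A sequence of 100-dim float vectors; InputVariable attaches the default dynamic sequence axis.
    auto features = InputVariable({ 100 }, DataType::Float, L"features");

    auto firstStep = Sequence::First(features);      // same as Sequence::Slice(features, 0, 1)
    auto lastStep  = Sequence::Last(features);       // same as Sequence::Slice(features, -1, 0)
    auto trimmed   = Sequence::Slice(features, 1, 0);// drop the first element of every sequence
    auto summed    = Sequence::ReduceSum(features);  // sum across the sequence axis
}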
@@ -206,9 +206,11 @@ namespace CNTK
     CNTK_API FunctionPtr GatherPacked(const Variable& operand, const Variable& packedIndex, const std::wstring& name = L"");
     CNTK_API FunctionPtr ScatterPacked(const Variable& operand, const Variable& packedIndex, const Variable& condition, const std::wstring& name = L"");
     CNTK_API FunctionPtr ZeroesWithDynamicAxesLike(const Variable& operand);
-    CNTK_API FunctionPtr Where(const Variable& condition, const std::vector<Axis>& newDynamicAxes, const std::wstring& name = L"");
-    CNTK_API FunctionPtr Gather(const Variable& operand, const Variable& condition, const std::vector<Axis>& newDynamicAxes, const std::wstring& name = L"");
-    CNTK_API FunctionPtr Scatter(const Variable& operand, const Variable& condition, const std::vector<Axis>& newDynamicAxes, const std::wstring& name = L"");
+    CNTK_API FunctionPtr Where(const Variable& condition, const std::pair<size_t, int>& newDerivedSequenceAxisScalingAndAdditiveFactor, const std::wstring& name = L"");
+    CNTK_API FunctionPtr Gather(const Variable& operand, const Variable& condition, const std::wstring& name = L"");
+    CNTK_API FunctionPtr Gather(const Variable& operand, const Variable& condition, const std::pair<size_t, int>& newDerivedSequenceAxisScalingAndAdditiveFactor, const std::wstring& name = L"");
+    CNTK_API FunctionPtr Scatter(const Variable& operand, const Variable& condition, const std::wstring& name = L"");
+    CNTK_API FunctionPtr Scatter(const Variable& operand, const Variable& condition, const std::pair<size_t, int>& newDerivedSequenceAxisScalingAndAdditiveFactor, const std::wstring& name = L"");
     CNTK_API FunctionPtr Slice(const Variable& operand, const Axis& axis, int beginIndex, int endIndex, const std::wstring& name = L"");
     CNTK_API FunctionPtr ReduceElements(const Variable& operand, const std::wstring& reductionOpName, const Axis& axis, const std::wstring& name = L"");
 
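The std::pair<size_t, int> argument replaces the explicit vector of new dynamic axes: it describes the derived sequence axis length as scalingFactor * sourceLength + additiveFactor. A hedged illustration (the helper name is mine), mirroring the multiplicativeFactor/sliceLength computation that appears later in this commit:

// How a sequence slice maps to the (scaling, additive) encoding:
//   Slice(x, 1, 0) -> { 1, -1 }   same rate as the source, one element shorter
//   Slice(x, 0, 1) -> { 0,  1 }   constant length 1 (e.g. First/Last)
std::pair<size_t, int> SliceAxisFactor(int beginIndex, int endIndex)
{
    int sliceLength = endIndex - beginIndex;
    size_t scalingFactor = (sliceLength > 0) ? 0 : 1; // fixed length vs. offset from the source length
    return { scalingFactor, sliceLength };
}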
@@ -544,6 +544,12 @@ namespace CNTK
         return CompositeFunction::Deserialize(modelDictionary, device);
     }
 
+    void Function::PrintGraph() const
+    {
+        CompositeFunction::Traverse(RootFunction(), [](const FunctionPtr& function) {
+        });
+    }
+
     // Names for the reduction operations as used by the CNTK ReduceElementsNode
     /*static*/ const std::wstring PrimitiveFunction::InternalSumReductionOpName = L"Sum";
     /*static*/ const std::wstring PrimitiveFunction::InternalLogSumReductionOpName = L"LogSum";
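PrintGraph above passes an (empty) functor to the new static CompositeFunction::Traverse; a sketch of a visitor that actually prints each node (the printing body is illustrative, not from the diff):

// Visits every Function reachable from 'root' exactly once and prints its name to stderr.
void PrintAllNodes(const FunctionPtr& root)
{
    CompositeFunction::Traverse(root, [](const FunctionPtr& function) {
        fprintf(stderr, "%S\n", function->Name().c_str());
    });
}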
@@ -580,6 +586,8 @@ namespace CNTK
     /*static*/ const std::wstring PrimitiveFunction::AttributeNameEpsilon = L"epsilon";
     /*static*/ const std::wstring PrimitiveFunction::AttributeNameUseCuDNNEngine = L"useCuDNNEngine";
     /*static*/ const std::wstring PrimitiveFunction::AttributeNameNewDynamicAxes = L"newDynamicAxes";
+    /*static*/ const std::wstring PrimitiveFunction::AttributeNameNewSequenceAxisLengthScalingFactor = L"newSequenceAxisLengthScalingFactor";
+    /*static*/ const std::wstring PrimitiveFunction::AttributeNameNewSequenceAxisLengthAdditiveFactor = L"newSequenceAxisLengthAdditiveFactor";
     /*static*/ const std::wstring PrimitiveFunction::AttributeNameBeginIndex = L"beginIndex";
     /*static*/ const std::wstring PrimitiveFunction::AttributeNameEndIndex = L"endIndex";
     /*static*/ const std::wstring PrimitiveFunction::AttributeNameReductionOpName = L"reductionOpName";
@@ -631,7 +639,36 @@ namespace CNTK
             if ((op == PrimitiveOpType::SumAll) || (op == PrimitiveOpType::SquaredError) || (op == PrimitiveOpType::CrossEntropyWithSoftmax) || (op == PrimitiveOpType::ClassificationError))
                 outputDynamicAxes = std::vector<Axis>({});
             else if (op == PrimitiveOpType::Where)
-                outputDynamicAxes = AsVector<Axis>(functionConfig[PrimitiveFunction::AttributeNameNewDynamicAxes].Value<std::vector<DictionaryValue>>());
+            {
+                if (functionConfig.Contains(PrimitiveFunction::AttributeNameNewDynamicAxes))
+                    outputDynamicAxes = AsVector<Axis>(functionConfig[PrimitiveFunction::AttributeNameNewDynamicAxes].Value<std::vector<DictionaryValue>>());
+                else
+                {
+                    if (inputs[0].DynamicAxes() == Axis::UnknownDynamicAxes())
+                        outputDynamicAxes = Axis::UnknownDynamicAxes();
+                    else
+                    {
+                        if (functionConfig.Contains(PrimitiveFunction::AttributeNameNewSequenceAxisLengthScalingFactor) &&
+                            functionConfig.Contains(PrimitiveFunction::AttributeNameNewSequenceAxisLengthAdditiveFactor))
+                        {
+                            size_t newSequenceAxisLengthScalingFactor = functionConfig[PrimitiveFunction::AttributeNameNewSequenceAxisLengthScalingFactor].Value<size_t>();
+                            int newSequenceAxisLengthAdditiveFactor = functionConfig[PrimitiveFunction::AttributeNameNewSequenceAxisLengthAdditiveFactor].Value<int>();
+
+                            auto derivedDynamicAxes = GetDerivedDynamicAxes(inputs[0].DynamicAxes()[0], newSequenceAxisLengthScalingFactor, newSequenceAxisLengthAdditiveFactor);
+                            std::copy(derivedDynamicAxes.begin(), derivedDynamicAxes.end(), std::back_inserter(outputDynamicAxes));
+                        }
+                        else
+                        {
+                            outputDynamicAxes.push_back(Axis::NewUniqueDynamicAxis(L"whereNodeDynamicAxis"));
+                        }
+
+                        for (size_t i = 1; i < inputs[0].DynamicAxes().size(); ++i)
+                            outputDynamicAxes.push_back(inputs[0].DynamicAxes()[i]);
+
+                        functionConfig[PrimitiveFunction::AttributeNameNewDynamicAxes] = AsDictionaryValueVector(outputDynamicAxes);
+                    }
+                }
+            }
             else if (op == PrimitiveOpType::ScatterPacked)
                 outputDynamicAxes = inputs[2].DynamicAxes();
             else if ((op == PrimitiveOpType::PackedIndex) || (op == PrimitiveOpType::GatherPacked))
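The Where branch above resolves the output dynamic axes in two ways; condensed into one illustrative helper (my names, assuming GetDerivedDynamicAxes encodes derivedLength = scale * sourceLength + additive):

std::vector<Axis> DeriveWhereOutputAxes(const Variable& input, bool hasLengthFactors, size_t scale, int additive)
{
    std::vector<Axis> outputDynamicAxes;
    if (hasLengthFactors) // e.g. a Where generated by Sequence::Slice
    {
        auto derived = GetDerivedDynamicAxes(input.DynamicAxes()[0], scale, additive);
        std::copy(derived.begin(), derived.end(), std::back_inserter(outputDynamicAxes));
    }
    else // data-dependent selection: the new length is unknowable, so mint a fresh axis
        outputDynamicAxes.push_back(Axis::NewUniqueDynamicAxis(L"whereNodeDynamicAxis"));

    // Remaining (batch) axes carry over unchanged.
    for (size_t i = 1; i < input.DynamicAxes().size(); ++i)
        outputDynamicAxes.push_back(input.DynamicAxes()[i]);

    return outputDynamicAxes;
}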
@@ -1098,7 +1135,7 @@ namespace CNTK
         std::vector<FunctionPtr> topoSortedPrimitiveFunctions;
         std::vector<Variable> inputs;
         std::unordered_set<std::wstring> inputUids;
-        Traverse([&visitedFunctions, &inputs, &topoSortedPrimitiveFunctions, &inputUids](const FunctionPtr& function) {
+        Traverse(RootFunction(), [&visitedFunctions, &inputs, &topoSortedPrimitiveFunctions, &inputUids](const FunctionPtr& function) {
             std::vector<Variable> functionInputs = function->Inputs();
             for (const auto& input : functionInputs)
             {
@@ -2585,7 +2622,7 @@ namespace CNTK
 
     FunctionPtr Round(const Variable& operand, const std::wstring& name)
     {
-        return Floor(Plus(operand, Constant::Scalar(operand.GetDataType(), 0.5)), name);
+        return Floor(Plus(operand, Constant::Scalar(0.5f)), name);
     }
 
     FunctionPtr Floor(const Variable& operand, const std::wstring& name)
@@ -2633,11 +2670,9 @@ namespace CNTK
         return TransposeAxes(operand, Axis(0), Axis(1), name);
     }
 
     FunctionPtr Slice(const Variable& operand, const Axis& axis, int beginIndex, int endIndex, const std::wstring& name)
     {
-        if (axis == Axis::DefaultBatchAxis())
-            LogicError("Slice is currently unsupported along the batch axis");
-
         if (axis.IsStaticAxis())
         {
             if ((endIndex - beginIndex) <= 0)
@@ -2646,46 +2681,10 @@ namespace CNTK
             return Internal::Slice(operand, axis, beginIndex, endIndex, name);
         }
 
-        if ((beginIndex == 0) && (endIndex == 0))
-            return operand;
-
-        auto operandAxes = operand.DynamicAxes();
-        auto findAxis = std::find(operandAxes.begin(), operandAxes.end(), axis);
-        if (findAxis == operandAxes.end())
-            InvalidArgument("The specified dynamic axis named %S does not match any of the dynamic axes of the operand", axis.Name().c_str());
-
-        auto beginFlagsLambda = [beginIndex, operand]() {
-            return (beginIndex > 0) ? Minus(Constant::Scalar(operand.GetDataType(), 1.0), Internal::IsWithin(operand, beginIndex)) : Internal::IsWithin(operand, beginIndex);
-        };
-
-        auto endFlagsLambda = [endIndex, operand]() {
-            return (endIndex > 0) ? Internal::IsWithin(operand, endIndex) : Minus(Constant::Scalar(operand.GetDataType(), 1.0), Internal::IsWithin(operand, endIndex));
-        };
-
-        FunctionPtr flags;
-        if (beginIndex == 0)
-            flags = endFlagsLambda();
-        else if (endIndex == 0)
-            flags = beginFlagsLambda();
-        else
-            flags = ElementTimes(beginFlagsLambda(), endFlagsLambda());
-
-        // Since we are slicing along a dynamic axis, the output variable's dynamic axes will be different than the operand
-        std::vector<Axis> newDynamicAxes;
-        for (auto operandAxis : operandAxes)
-        {
-            if (operandAxis == axis)
-            {
-                int sliceLength = (endIndex - beginIndex);
-                size_t multiplicativeFactor = (sliceLength > 0) ? 0 : 1;
-                auto derivedDynamicAxes = GetDerivedDynamicAxes(operandAxis, multiplicativeFactor, sliceLength);
-                std::copy(derivedDynamicAxes.begin(), derivedDynamicAxes.end(), std::back_inserter(newDynamicAxes));
-            }
-            else
-                newDynamicAxes.push_back(operandAxis);
-        }
-
-        return Internal::Gather(operand, flags, newDynamicAxes, name);
+        if (axis == Axis::DefaultBatchAxis())
+            LogicError("Slice is currently unsupported along the batch axis");
+
+        LogicError("CNTK::Slice: Invalid axis argument provided. To slice a sequence along its ordered dynamic axis use Sequence::Slice.");
     }
 
     FunctionPtr RandomSample(const Variable& operand, size_t numSamples, bool allowDuplicates, const std::wstring& name)
@@ -2721,6 +2720,7 @@ namespace CNTK
 
         return UnaryOp(PrimitiveOpType::Reshape, operand, std::move(additionalProperties), name);
     }
+
     FunctionPtr BinaryOp(PrimitiveOpType op, const Variable& leftOperand, const Variable& rightOperand, Dictionary&& opConfig, const std::wstring& name)
     {
         std::vector<Variable> operands = { leftOperand, rightOperand };
@@ -2815,14 +2815,14 @@ namespace CNTK
         if (topN == 1)
         {
             if (axis == Axis(0))
-                return Minus(Constant::Scalar(prediction.GetDataType(), 1.0), TransposeTimes(labels, Hardmax(prediction)), name);
+                return Minus(Constant::Scalar(1.0f), TransposeTimes(labels, Hardmax(prediction)), name);
             else
             {
                 auto axMax = ReduceMax(prediction, axis);
                 auto pred = Equal(prediction, axMax);
                 auto wrongPred = NotEqual(labels, pred);
                 auto axErr = ReduceSum(wrongPred, axis);
-                auto capErr = GreaterEqual(axErr, Constant::Scalar(prediction.GetDataType(), 1.0));
+                auto capErr = GreaterEqual(axErr, Constant::Scalar(1.0f));
                 return ReduceMean(capErr, Axis::AllStaticAxes(), name);
             }
         }
@@ -2831,7 +2831,7 @@ namespace CNTK
         if (axis != Axis(0))
             LogicError("ClassificationError along a specific axis does not support topN!");
 
-        std::vector<Variable> operands = { prediction, labels, Constant::Scalar(prediction.GetDataType(), (double)topN) };
+        std::vector<Variable> operands = { prediction, labels, Constant::Scalar((float)topN) };
         return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::ClassificationError, operands, Dictionary(), name), name);
     }
 }
@@ -3011,75 +3011,113 @@ namespace CNTK
     {
         // TODO: This is a temporary and expensive hack until we have a real alias implementation
         // that does not waste memory and compute cycles
-        return Plus(operand, Constant::Scalar(operand.GetDataType(), 0), name);
+        return Plus(operand, Constant::Scalar(0.0f), name);
     }
 
     namespace Sequence
     {
         void VerifyIsSequence(const Variable& operand)
         {
-            // The operand must have at least one dynamic axis and its first dynamic axis must be ordered
-            if (operand.DynamicAxes().empty() || !operand.DynamicAxes()[0].IsOrdered())
+            // The operand must have at least one dynamic axis
+            if (operand.DynamicAxes().empty())
                 InvalidArgument("A sequence function can only be applied on operands with at least one dynamic axis and whose first dynamic axis is ordered");
         }
 
         FunctionPtr IsFirst(const Variable& operand, const std::wstring& name)
         {
             VerifyIsSequence(operand);
             return Internal::IsWithin(operand, 1, name);
         }
 
         FunctionPtr IsLast(const Variable& operand, const std::wstring& name)
         {
             VerifyIsSequence(operand);
             return Internal::IsWithin(operand, -1, name);
         }
 
+        FunctionPtr Slice(const Variable& operand, int beginIndex, int endIndex, const std::wstring& name)
+        {
+            VerifyIsSequence(operand);
+
+            if ((beginIndex == 0) && (endIndex == 0))
+                return operand;
+
+            auto beginFlagsLambda = [beginIndex, operand]() {
+                return (beginIndex > 0) ? Minus(Constant::Scalar(1.0f), Internal::IsWithin(operand, beginIndex)) : Internal::IsWithin(operand, beginIndex);
+            };
+
+            auto endFlagsLambda = [endIndex, operand]() {
+                return (endIndex > 0) ? Internal::IsWithin(operand, endIndex) : Minus(Constant::Scalar(1.0f), Internal::IsWithin(operand, endIndex));
+            };
+
+            FunctionPtr flags;
+            if (beginIndex == 0)
+                flags = endFlagsLambda();
+            else if (endIndex == 0)
+                flags = beginFlagsLambda();
+            else
+                flags = ElementTimes(beginFlagsLambda(), endFlagsLambda());
+
+            int sliceLength = (endIndex - beginIndex);
+            size_t multiplicativeFactor = (sliceLength > 0) ? 0 : 1;
+
+            return Internal::Gather(operand, flags, { multiplicativeFactor, sliceLength }, name);
+        }
+
         FunctionPtr First(const Variable& operand, const std::wstring& name)
         {
             VerifyIsSequence(operand);
-            return Slice(operand, operand.DynamicAxes()[0], 0, 1, name);
+            return Sequence::Slice(operand, 0, 1, name);
         }
 
         FunctionPtr Last(const Variable& operand, const std::wstring& name)
         {
             VerifyIsSequence(operand);
-            return Slice(operand, operand.DynamicAxes()[0], -1, 0, name);
-        }
-
-        std::vector<Axis> WhereOpDynamicAxes(const Variable& operand)
-        {
-            VerifyIsSequence(operand);
-
-            std::vector<Axis> newDynamicAxes = { Axis::NewUniqueDynamicAxis(L"whereNodeDynamicAxis") };
-            for (size_t i = 1; i < operand.DynamicAxes().size(); ++i)
-                newDynamicAxes.push_back(operand.DynamicAxes()[i]);
-
-            return newDynamicAxes;
+            return Sequence::Slice(operand, -1, 0, name);
         }
 
         FunctionPtr Where(const Variable& condition, const std::wstring& name)
         {
-            return Internal::Where(condition, WhereOpDynamicAxes(condition), name);
+            return UnaryOp(PrimitiveOpType::Where, condition, Dictionary(), name);
         }
 
         FunctionPtr Gather(const Variable& operand, const Variable& condition, const std::wstring& name)
        {
-            return Internal::Gather(operand, condition, WhereOpDynamicAxes(condition), name);
+            return Internal::Gather(operand, condition, name);
         }
 
         FunctionPtr Scatter(const Variable& operand, const Variable& condition, const std::wstring& name)
         {
-            return Internal::Scatter(operand, condition, WhereOpDynamicAxes(condition), name);
+            return Internal::Scatter(operand, condition, name);
         }
 
         FunctionPtr BroadcastAs(const Variable& operand, const Variable& broadcastAs, const std::wstring& name)
         {
-            auto dataPadded = Internal::Scatter(operand, Sequence::IsFirst(broadcastAs), operand.DynamicAxes());
+            auto dataPadded = Internal::Scatter(operand, Sequence::IsFirst(broadcastAs), std::make_pair<size_t, int>(0, 1));
             auto placeHolderOutput = PlaceholderVariable(operand.Shape(), broadcastAs.DynamicAxes());
             auto output = ElementSelect(Sequence::IsFirst(broadcastAs), dataPadded, PastValue(placeHolderOutput), name);
             return output->ReplacePlaceholders({ { placeHolderOutput, output } });
         }
 
+        FunctionPtr ReduceElements(const Variable& operand, const std::wstring& reductionOpName, const std::wstring& name)
+        {
+            using namespace std::placeholders;
+
+            std::function<FunctionPtr(const Variable& leftOperand, const Variable& rightOperand)> reductionFunctor;
+            if (reductionOpName == PrimitiveFunction::InternalSumReductionOpName)
+                reductionFunctor = std::bind(Plus, _1, _2, L"");
+            else
+                LogicError("%S reduction along dynamic axis is currently unsupported", reductionOpName.c_str());
+
+            // We are reducing over a dynamic axis which is currently implemented using recurrence
+            auto cumulativeSumFunctionPlaceholder = PlaceholderVariable(operand.Shape());
+            auto prevAccumulatedValuesFunction = PastValue(cumulativeSumFunctionPlaceholder);
+            auto cumulativeSumFunction = reductionFunctor(prevAccumulatedValuesFunction, operand);
+            cumulativeSumFunction->ReplacePlaceholders({ { cumulativeSumFunctionPlaceholder, cumulativeSumFunction } });
+
+            return Sequence::Slice(cumulativeSumFunction, -1, 0, name);
+        }
+
+        FunctionPtr ReduceSum(const Variable& operand, const std::wstring& name)
+        {
+            return ReduceElements(operand, PrimitiveFunction::InternalSumReductionOpName, name);
+        }
     }
 
     namespace Internal
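Sequence::ReduceElements above realizes the reduction as a recurrence, s_t = s_{t-1} + x_t, and then reads the final step. The same pattern, reduced to its core (a sketch, not a second implementation):

FunctionPtr RunningSumThenLast(const Variable& operand, const std::wstring& name)
{
    auto placeholder = PlaceholderVariable(operand.Shape());
    auto runningSum = Plus(PastValue(placeholder), operand);  // s_t = s_{t-1} + x_t
    runningSum->ReplacePlaceholders({ { placeholder, runningSum } });
    return Sequence::Slice(runningSum, -1, 0, name);          // s_T, the sum over the whole sequence
}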
@@ -3092,9 +3130,9 @@ namespace CNTK
             InvalidArgument("CNTK::Sequence::IsWithin: The offset must be positive");
 
         if (offset > 0)
-            return PastValue(Internal::ZeroesWithDynamicAxesLike(operand), Constant::Scalar(operand.GetDataType(), 1.0), offset, name);
+            return PastValue(Internal::ZeroesWithDynamicAxesLike(operand), Constant::Scalar(1.0f), offset, name);
         else
-            return FutureValue(Internal::ZeroesWithDynamicAxesLike(operand), Constant::Scalar(operand.GetDataType(), 1.0), -offset, name);
+            return FutureValue(Internal::ZeroesWithDynamicAxesLike(operand), Constant::Scalar(1.0f), -offset, name);
     }
 
     FunctionPtr PackedIndex(const Variable& operand, const Variable& index, const std::wstring& name)
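IsWithin builds per-step 0/1 flags by shifting an all-zero sequence with an initial state of 1, so the 1 appears only in the shifted-in positions. Sketched for both directions:

// IsWithin(x, 1):  PastValue(zeros, 1.0f, 1)   -> 1 0 0 ... 0   (flags the first step; Sequence::IsFirst)
// IsWithin(x, -1): FutureValue(zeros, 1.0f, 1) -> 0 0 ... 0 1   (flags the last step; Sequence::IsLast)
FunctionPtr FirstStepFlags(const Variable& x) { return Internal::IsWithin(x, 1); }
FunctionPtr LastStepFlags(const Variable& x)  { return Internal::IsWithin(x, -1); }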
@@ -3131,21 +3169,32 @@ namespace CNTK
             }
         }
 
-        FunctionPtr Where(const Variable& condition, const std::vector<Axis>& newDynamicAxes, const std::wstring& name)
+        FunctionPtr Where(const Variable& condition, const std::pair<size_t, int>& newDerivedSequenceAxisScalingAndAdditiveFactor, const std::wstring& name)
         {
             auto additionalProperties = Dictionary();
-            additionalProperties[PrimitiveFunction::AttributeNameNewDynamicAxes] = AsDictionaryValueVector(newDynamicAxes);
+            additionalProperties[PrimitiveFunction::AttributeNameNewSequenceAxisLengthScalingFactor] = newDerivedSequenceAxisScalingAndAdditiveFactor.first;
+            additionalProperties[PrimitiveFunction::AttributeNameNewSequenceAxisLengthAdditiveFactor] = newDerivedSequenceAxisScalingAndAdditiveFactor.second;
             return UnaryOp(PrimitiveOpType::Where, condition, std::move(additionalProperties), name);
         }
 
-        FunctionPtr Gather(const Variable& operand, const Variable& condition, const std::vector<Axis>& newDynamicAxes, const std::wstring& name)
+        FunctionPtr Gather(const Variable& operand, const Variable& condition, const std::wstring& name)
         {
-            return Internal::GatherPacked(operand, Internal::PackedIndex(/*layout of*/ operand, Where(condition, newDynamicAxes)), name);
+            return Internal::GatherPacked(operand, Internal::PackedIndex(/*layout of*/ operand, Sequence::Where(condition)), name);
         }
 
-        FunctionPtr Scatter(const Variable& operand, const Variable& condition, const std::vector<Axis>& whereNodeDynamicAxes, const std::wstring& name)
+        FunctionPtr Gather(const Variable& operand, const Variable& condition, const std::pair<size_t, int>& newDerivedSequenceAxisScalingAndAdditiveFactor, const std::wstring& name)
         {
-            return Internal::ScatterPacked(operand, Internal::PackedIndex(/*layout of*/ condition, Where(condition, whereNodeDynamicAxes)), /*layout of*/ condition, name);
+            return Internal::GatherPacked(operand, Internal::PackedIndex(/*layout of*/ operand, Where(condition, newDerivedSequenceAxisScalingAndAdditiveFactor)), name);
+        }
+
+        FunctionPtr Scatter(const Variable& operand, const Variable& condition, const std::wstring& name)
+        {
+            return Internal::ScatterPacked(operand, Internal::PackedIndex(/*layout of*/ condition, Sequence::Where(condition)), /*layout of*/ condition, name);
+        }
+
+        FunctionPtr Scatter(const Variable& operand, const Variable& condition, const std::pair<size_t, int>& newDerivedSequenceAxisScalingAndAdditiveFactor, const std::wstring& name)
+        {
+            return Internal::ScatterPacked(operand, Internal::PackedIndex(/*layout of*/ condition, Where(condition, newDerivedSequenceAxisScalingAndAdditiveFactor)), /*layout of*/ condition, name);
         }
 
         FunctionPtr Slice(const Variable& operand, const Axis& axis, int beginIndex, int endIndex, const std::wstring& name)
@@ -3160,8 +3209,6 @@ namespace CNTK
 
         FunctionPtr ReduceElements(const Variable& operand, const std::wstring& reductionOpName, const Axis& axis, const std::wstring& name)
         {
-            using namespace std::placeholders;
-
             if (axis.IsStaticAxis() || (axis == Axis::AllStaticAxes()))
             {
                 auto additionalProperties = Dictionary();
@@ -3173,20 +3220,7 @@ namespace CNTK
             if (axis == Axis::DefaultBatchAxis())
                 LogicError("Reduction is currently unsupported along the batch axis");
 
-            if (reductionOpName != PrimitiveFunction::InternalSumReductionOpName)
-                LogicError("%S reduction along dynamic axis is currently unsupported", reductionOpName.c_str());
-
-            std::function<FunctionPtr(const Variable& leftOperand, const Variable& rightOperand)> reductionFunctor;
-            if (reductionOpName == PrimitiveFunction::InternalSumReductionOpName)
-                reductionFunctor = std::bind(Plus, _1, _2, L"");
-
-            // We are reducing over a dynamic axis which is currently implemented using recurrence
-            auto cumulativeSumFunctionPlaceholder = PlaceholderVariable(operand.Shape());
-            auto prevAccumulatedValuesFunction = PastValue(cumulativeSumFunctionPlaceholder);
-            auto cumulativeSumFunction = reductionFunctor(prevAccumulatedValuesFunction, operand);
-            cumulativeSumFunction->ReplacePlaceholders({ { cumulativeSumFunctionPlaceholder, cumulativeSumFunction } });
-
-            return CNTK::Slice(cumulativeSumFunction, axis, -1, 0, name);
+            LogicError("CNTK::ReduceElements: Invalid axis argument provided. To reduce a sequence along its ordered dynamic axis use Sequence::ReduceElements.");
         }
     }
 }
@@ -187,6 +187,8 @@ namespace CNTK
         static const std::wstring AttributeNameEpsilon;
         static const std::wstring AttributeNameUseCuDNNEngine;
         static const std::wstring AttributeNameNewDynamicAxes;
+        static const std::wstring AttributeNameNewSequenceAxisLengthScalingFactor;
+        static const std::wstring AttributeNameNewSequenceAxisLengthAdditiveFactor;
         static const std::wstring AttributeNameBeginIndex;
         static const std::wstring AttributeNameEndIndex;
         static const std::wstring AttributeNameReductionOpName;
@@ -699,22 +701,11 @@ namespace CNTK
             return CompositeFunctionOpName;
         }
 
-    private:
-        virtual void ReplacePlaceholdersInPlace(const std::unordered_map<Variable, Variable>& placeholderReplacements,
-                                                std::unordered_set<const Function*>& visitedFunctions,
-                                                std::unordered_set<Variable>& replacedPlaceholders) override;
-
-        CompositeFunction(const FunctionPtr& rootFunction, std::unordered_set<FunctionPtr>&& allPrimitiveFunctions, const std::wstring& name, const std::wstring& uid = Internal::GenerateUid(L"CompositeFunction"))
-            : Function({}, rootFunction->Outputs(), Dictionary(), rootFunction, name, uid),
-              m_allPrimitiveFunctions(std::move(allPrimitiveFunctions)), m_networkMatricesAllocated(false)
-        {}
-
         template <typename FunctionType>
-        void Traverse(const FunctionType& functor) const
+        static void Traverse(const FunctionPtr& rootFunction, const FunctionType& functor)
         {
-            const auto& root = RootFunction();
             std::unordered_set<FunctionPtr> visitedFunctions;
-            Traverse(root, visitedFunctions, functor);
+            Traverse(rootFunction, visitedFunctions, functor);
         }
 
         // Recursively traverses the Function graph underlying the 'rootFunction' invoking the provided functor for all visited nodes in the graph.
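Making Traverse static decouples it from a particular composite instance, so any root can be walked. For example, a node counter (a sketch, assuming access to this internal header):

size_t CountReachableFunctions(const FunctionPtr& root)
{
    size_t count = 0;
    CompositeFunction::Traverse(root, [&count](const FunctionPtr&) { ++count; });
    return count;
}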
@@ -735,6 +726,16 @@ namespace CNTK
             }
         }
 
+    private:
+        virtual void ReplacePlaceholdersInPlace(const std::unordered_map<Variable, Variable>& placeholderReplacements,
+                                                std::unordered_set<const Function*>& visitedFunctions,
+                                                std::unordered_set<Variable>& replacedPlaceholders) override;
+
+        CompositeFunction(const FunctionPtr& rootFunction, std::unordered_set<FunctionPtr>&& allPrimitiveFunctions, const std::wstring& name, const std::wstring& uid = Internal::GenerateUid(L"CompositeFunction"))
+            : Function({}, rootFunction->Outputs(), Dictionary(), rootFunction, name, uid),
+              m_allPrimitiveFunctions(std::move(allPrimitiveFunctions)), m_networkMatricesAllocated(false)
+        {}
+
         std::vector<Variable> DetermineInputs() const
         {
             const auto& root = RootFunction();
@@ -200,7 +200,6 @@ inline CNTK::FunctionPtr Stabilize(const CNTK::Variable& x, const CNTK::DeviceDe
 template <typename ElementType>
 std::pair<CNTK::FunctionPtr, CNTK::FunctionPtr> LSTMPCellWithSelfStabilization(CNTK::Variable input, CNTK::Variable prevOutput, CNTK::Variable prevCellState, const CNTK::DeviceDescriptor& device)
 {
-    size_t inputDim = input.Shape()[0];
     size_t outputDim = prevOutput.Shape()[0];
     size_t cellDim = prevCellState.Shape()[0];
 
@@ -209,8 +208,8 @@ std::pair<CNTK::FunctionPtr, CNTK::FunctionPtr> LSTMPCellWithSelfStabilization(C
     };
 
     unsigned long seed = 1;
-    auto createProjectionParam = [device, &seed](size_t outputDim, size_t inputDim) {
-        return CNTK::Parameter({ outputDim, inputDim }, CNTK::AsDataType<ElementType>(), CNTK::GlorotUniformInitializer(1, 0, 1, seed++), device);
+    auto createProjectionParam = [device, &seed](size_t outputDim) {
+        return CNTK::Parameter({ outputDim, CNTK::NDShape::InferredDimension }, CNTK::AsDataType<ElementType>(), CNTK::GlorotUniformInitializer(1, 0, 1, seed++), device);
     };
 
     auto createDiagWeightParam = [device, &seed](size_t dim) {
@@ -220,26 +219,26 @@ std::pair<CNTK::FunctionPtr, CNTK::FunctionPtr> LSTMPCellWithSelfStabilization(C
     auto stabilizedPrevOutput = Stabilize<ElementType>(prevOutput, device);
     auto stabilizedPrevCellState = Stabilize<ElementType>(prevCellState, device);
 
-    auto projectInput = [input, cellDim, inputDim, createBiasParam, createProjectionParam]() {
-        return createBiasParam(cellDim) + CNTK::Times(createProjectionParam(cellDim, inputDim), input);
+    auto projectInput = [input, cellDim, createBiasParam, createProjectionParam]() {
+        return createBiasParam(cellDim) + CNTK::Times(createProjectionParam(cellDim), input);
     };
 
     // Input gate
-    auto it = CNTK::Sigmoid(projectInput() + CNTK::Times(createProjectionParam(cellDim, outputDim), stabilizedPrevOutput) + CNTK::ElementTimes(createDiagWeightParam(cellDim), stabilizedPrevCellState));
-    auto bit = CNTK::ElementTimes(it, CNTK::Tanh(projectInput() + CNTK::Times(createProjectionParam(cellDim, outputDim), stabilizedPrevOutput)));
+    auto it = CNTK::Sigmoid(projectInput() + CNTK::Times(createProjectionParam(cellDim), stabilizedPrevOutput) + CNTK::ElementTimes(createDiagWeightParam(cellDim), stabilizedPrevCellState));
+    auto bit = CNTK::ElementTimes(it, CNTK::Tanh(projectInput() + CNTK::Times(createProjectionParam(cellDim), stabilizedPrevOutput)));
 
     // Forget-me-not gate
-    auto ft = CNTK::Sigmoid(projectInput() + CNTK::Times(createProjectionParam(cellDim, outputDim), stabilizedPrevOutput) + CNTK::ElementTimes(createDiagWeightParam(cellDim), stabilizedPrevCellState));
+    auto ft = CNTK::Sigmoid(projectInput() + CNTK::Times(createProjectionParam(cellDim), stabilizedPrevOutput) + CNTK::ElementTimes(createDiagWeightParam(cellDim), stabilizedPrevCellState));
     auto bft = CNTK::ElementTimes(ft, prevCellState);
 
     auto ct = bft + bit;
 
     // Output gate
-    auto ot = CNTK::Sigmoid(projectInput() + CNTK::Times(createProjectionParam(cellDim, outputDim), stabilizedPrevOutput) + CNTK::ElementTimes(createDiagWeightParam(cellDim), Stabilize<ElementType>(ct, device)));
+    auto ot = CNTK::Sigmoid(projectInput() + CNTK::Times(createProjectionParam(cellDim), stabilizedPrevOutput) + CNTK::ElementTimes(createDiagWeightParam(cellDim), Stabilize<ElementType>(ct, device)));
     auto ht = CNTK::ElementTimes(ot, CNTK::Tanh(ct));
 
     auto c = ct;
-    auto h = (outputDim != cellDim) ? CNTK::Times(createProjectionParam(outputDim, cellDim), Stabilize<ElementType>(ht, device)) : ht;
+    auto h = (outputDim != cellDim) ? CNTK::Times(createProjectionParam(outputDim), Stabilize<ElementType>(ht, device)) : ht;
 
     return{ h, c };
 }
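With NDShape::InferredDimension the lambda no longer needs the input dimension: it is resolved when the parameter first participates in Times. A minimal sketch of that inference (dimensions here are illustrative):

auto W = CNTK::Parameter({ 256, CNTK::NDShape::InferredDimension }, CNTK::DataType::Float,
                         CNTK::GlorotUniformInitializer(), CNTK::DeviceDescriptor::CPUDevice());
auto x = CNTK::InputVariable({ 100 }, CNTK::DataType::Float, L"x");
auto y = CNTK::Times(W, x); // W's shape is inferred as { 256, 100 } from the operand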
@@ -99,18 +99,14 @@ void TestReduceSum(size_t sampleRank, const DeviceDescriptor& device)
 
     // Test ReduceSum along a dynamic axis
     {
-        auto testReduceSum = [&sequences, &sequenceLengths, inputShape, sequencesValue, device](const Axis& axis)
+        auto testReduceSum = [&sequences, &sequenceLengths, inputShape, sequencesValue, device]()
         {
-            if (!axis.IsDynamicAxis())
-                RuntimeError("Called the dynamic axis ReduceSum test with a static axis");
-
             size_t maxActualSequenceLength = sequencesValue->Shape()[inputShape.Rank()];
             size_t numSequences = sequencesValue->Shape()[inputShape.Rank() + 1];
 
             auto inputVar = InputVariable({ inputShape }, DataType::Float, L"input");
-            FunctionPtr reduceSumFunc = ReduceSum(inputVar, axis);
+            FunctionPtr reduceSumFunc = Sequence::ReduceSum(inputVar);
 
-            NDShape maskShape = { ((axis == Axis::DefaultBatchAxis()) ? maxActualSequenceLength : 1), ((axis == Axis::DefaultBatchAxis()) ? 1 : numSequences) };
+            NDShape maskShape = { 1, numSequences };
             NDShape outputShape = reduceSumFunc->Output().Shape();
             auto outputDataShape = outputShape.AppendShape(maskShape);
 
@@ -130,10 +126,7 @@ void TestReduceSum(size_t sampleRank, const DeviceDescriptor& device)
                     for (size_t k = 0; k < inputShape.TotalSize(); ++k)
                     {
                         float value = sequences[i][(j * inputShape.TotalSize()) + k];
-                        if (axis == Axis::DefaultBatchAxis())
-                            expectedTotals[(j * inputShape.TotalSize()) + k] += value;
-                        else
-                            expectedTotals[(i * inputShape.TotalSize()) + k] += value;
+                        expectedTotals[(i * inputShape.TotalSize()) + k] += value;
                     }
                 }
             }
@@ -141,7 +134,7 @@ void TestReduceSum(size_t sampleRank, const DeviceDescriptor& device)
             FloatingPointVectorCompare(outputData, expectedTotals, "testReduceSum: Forward prop results do not match expected results");
         };
 
-        testReduceSum(Axis::DefaultDynamicAxis());
+        testReduceSum();
     }
 }
 
@@ -217,11 +210,8 @@ void TestSlice(size_t sampleRank, const DeviceDescriptor& device)
 
     // Test slice along a dynamic axis
     {
-        auto testDynamicAxisSlice = [&sequences, &sequenceLengths, inputShape, sequencesValue, device](const Axis& axis, int beginOffset, int endOffset)
+        auto testDynamicAxisSlice = [&sequences, &sequenceLengths, inputShape, sequencesValue, device](int beginOffset, int endOffset)
         {
-            if (!axis.IsDynamicAxis())
-                RuntimeError("Called the dynamic axis slice test with a static axis");
-
             size_t maxActualSequenceLength = sequencesValue->Shape()[inputShape.Rank()];
             size_t numSequences = sequencesValue->Shape()[inputShape.Rank() + 1];
 
@@ -229,11 +219,11 @@ void TestSlice(size_t sampleRank, const DeviceDescriptor& device)
             size_t maxSliceLength = (endAndBeginOffsetDiff > 0) ? endAndBeginOffsetDiff : maxActualSequenceLength + endAndBeginOffsetDiff;
 
             auto inputVar = InputVariable(inputShape, DataType::Float, L"input");
-            auto sliceFunc = Slice(inputVar, axis, beginOffset, endOffset);
+            auto sliceFunc = Sequence::Slice(inputVar, beginOffset, endOffset);
             sliceFunc = sliceFunc + sliceFunc;
 
-            size_t outputSequenceAxisLength = (axis == Axis::DefaultDynamicAxis()) ? maxSliceLength : maxActualSequenceLength;
-            size_t outputBatchAxisLength = (axis == Axis::DefaultBatchAxis()) ? maxSliceLength : numSequences;
+            size_t outputSequenceAxisLength = maxSliceLength;
+            size_t outputBatchAxisLength = numSequences;
             NDShape outputShape = sliceFunc->Output().Shape().AppendShape({ outputSequenceAxisLength, outputBatchAxisLength });
             std::vector<float> outputData(outputShape.TotalSize(), 0);
             NDMaskPtr mask;
@@ -247,15 +237,15 @@ void TestSlice(size_t sampleRank, const DeviceDescriptor& device)
             std::unordered_map<Variable, ValuePtr> outputs = { { sliceFunc->Output(), outputValue } };
             sliceFunc->Forward({ { inputVar, sequencesValue } }, outputs, device);
 
-            size_t startSequenceIdx = (axis == Axis::DefaultBatchAxis()) ? ((beginOffset >= 0) ? beginOffset : (numSequences + beginOffset)) : 0;
-            size_t endSequenceIdx = (axis == Axis::DefaultBatchAxis()) ? ((endOffset > 0) ? endOffset : (numSequences + endOffset)) : numSequences;
+            size_t startSequenceIdx = 0;
+            size_t endSequenceIdx = numSequences;
 
             std::vector<float> expectedOutputValues(inputShape.TotalSize() * outputSequenceAxisLength * outputBatchAxisLength);
             for (size_t i = startSequenceIdx; i < endSequenceIdx; ++i)
             {
                 size_t currentSequenceLength = sequenceLengths[i];
-                size_t startFrameIdx = (axis == Axis::DefaultDynamicAxis()) ? ((beginOffset >= 0) ? beginOffset : (currentSequenceLength + beginOffset)) : 0;
-                size_t endFrameIdx = (axis == Axis::DefaultDynamicAxis()) ? ((endOffset > 0) ? endOffset : (currentSequenceLength + endOffset)) : currentSequenceLength;
+                size_t startFrameIdx = ((beginOffset >= 0) ? beginOffset : (currentSequenceLength + beginOffset));
+                size_t endFrameIdx = ((endOffset > 0) ? endOffset : (currentSequenceLength + endOffset));
                 size_t j = startFrameIdx;
                 for (; j < endFrameIdx; ++j)
                 {
@@ -272,12 +262,12 @@ void TestSlice(size_t sampleRank, const DeviceDescriptor& device)
             FloatingPointVectorCompare(outputData, expectedOutputValues, "testDynamicAxisSlice: Forward prop results do not match expected results");
         };
 
-        testDynamicAxisSlice(Axis::DefaultDynamicAxis(), 0, 1);
-        testDynamicAxisSlice(Axis::DefaultDynamicAxis(), 0, 2);
-        testDynamicAxisSlice(Axis::DefaultDynamicAxis(), -1, 0);
-        testDynamicAxisSlice(Axis::DefaultDynamicAxis(), -2, 0);
-        testDynamicAxisSlice(Axis::DefaultDynamicAxis(), 0, -1);
-        testDynamicAxisSlice(Axis::DefaultDynamicAxis(), 1, 0);
+        testDynamicAxisSlice(0, 1);
+        testDynamicAxisSlice(0, 2);
+        testDynamicAxisSlice(-1, 0);
+        testDynamicAxisSlice(-2, 0);
+        testDynamicAxisSlice(0, -1);
+        testDynamicAxisSlice(1, 0);
     }
 }
 
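The updated calls exercise Python-style begin/end offsets on the sequence axis; expected behavior for a 5-element sequence, following the startFrameIdx/endFrameIdx arithmetic in the test above:

// For a sequence [x0, x1, x2, x3, x4] (sketch of expected results):
//   Sequence::Slice(x,  0,  1) -> [x0]                  first element
//   Sequence::Slice(x,  0,  2) -> [x0, x1]
//   Sequence::Slice(x, -1,  0) -> [x4]                  last element
//   Sequence::Slice(x, -2,  0) -> [x3, x4]
//   Sequence::Slice(x,  0, -1) -> [x0, x1, x2, x3]      all but the last
//   Sequence::Slice(x,  1,  0) -> [x1, x2, x3, x4]      all but the first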
@@ -6,7 +6,7 @@ using namespace CNTK;
 
 using namespace std::placeholders;
 
-void TrainSequenceToSequenceTranslator(const DeviceDescriptor& device, bool useSparseInputs, bool testSaveAndReLoad, bool testCheckpointing, bool addBeamSearchReorderingHook, bool testCloning)
+void TrainSequenceToSequenceTranslator(const DeviceDescriptor& device, bool useSparseInputs, bool testSaveAndReLoad, bool testCheckpointing, bool addBeamSearchReorderingHook, bool testCloning, bool usePlaceholders)
 {
     using namespace std::placeholders;
 
@@ -30,7 +30,7 @@ void TrainSequenceToSequenceTranslator(const DeviceDescriptor& device, bool useS
     FunctionPtr inputSequence = Alias(rawInput, L"inputSequence");
 
     // Drop the sentence start token from the label, for decoder training
-    auto labelSequence = Slice(rawLabels, labelDynamicAxes[0], 1, 0, L"labelSequenceWithStartTrimmed");
+    auto labelSequence = Sequence::Slice(rawLabels, 1, 0, L"labelSequenceWithStartTrimmed");
     auto labelSentenceStart = Sequence::First(rawLabels, L"labelSequenceStart");
 
     auto isFirstLabel = Sequence::IsFirst(labelSequence, L"isFirstLabel");
@@ -38,8 +38,8 @@ void TrainSequenceToSequenceTranslator(const DeviceDescriptor& device, bool useS
     bool forceEmbedding = useSparseInputs;
 
     /* Embeddings */
-    auto inputEmbeddingWeights = Parameter({ inputEmbeddingDim, inputVocabDim }, DataType::Float, GlorotUniformInitializer(), device, L"inputEmbeddingWeights");
-    auto labelEmbeddingWeights = Parameter({ labelEmbeddingDim, labelVocabDim }, DataType::Float, GlorotUniformInitializer(), device, L"labelEmbeddingWeights");
+    auto inputEmbeddingWeights = Parameter({ inputEmbeddingDim, NDShape::InferredDimension }, DataType::Float, GlorotUniformInitializer(), device, L"inputEmbeddingWeights");
+    auto labelEmbeddingWeights = Parameter({ labelEmbeddingDim, NDShape::InferredDimension }, DataType::Float, GlorotUniformInitializer(), device, L"labelEmbeddingWeights");
 
     auto inputEmbedding = Alias((!forceEmbedding && (inputVocabDim <= inputEmbeddingDim)) ? inputSequence : Times(inputEmbeddingWeights, inputSequence), L"inputEmbedding");
     auto labelEmbedding = Alias((!forceEmbedding && (labelVocabDim <= labelEmbeddingDim)) ? labelSequence : Times(labelEmbeddingWeights, labelSequence), L"labelEmbedding");
@@ -63,8 +63,20 @@ void TrainSequenceToSequenceTranslator(const DeviceDescriptor& device, bool useS
         labelSentenceStartEmbeddedScattered = Reshape(labelSentenceStartEmbeddedScattered, labelSentenceStartEmbeddedScattered->Output().Shape().AppendShape({ 1 }), L"labelSentenceStartEmbeddedScattered");
     }
 
-    auto thoughtVectorBroadcastH = Sequence::BroadcastAs(thoughtVectorH, labelEmbedding, L"thoughtVectorBroadcastH");
-    auto thoughtVectorBroadcastC = Sequence::BroadcastAs(thoughtVectorC, labelEmbedding, L"thoughtVectorBroadcastC");
+    auto actualThoughtVectorBroadcastH = Sequence::BroadcastAs(thoughtVectorH, labelEmbedding, L"thoughtVectorBroadcastH");
+    auto actualThoughtVectorBroadcastC = Sequence::BroadcastAs(thoughtVectorC, labelEmbedding, L"thoughtVectorBroadcastC");
+
+    Variable thoughtVectorBroadcastH, thoughtVectorBroadcastC;
+    if (usePlaceholders)
+    {
+        thoughtVectorBroadcastH = PlaceholderVariable();
+        thoughtVectorBroadcastC = PlaceholderVariable();
+    }
+    else
+    {
+        thoughtVectorBroadcastH = actualThoughtVectorBroadcastH;
+        thoughtVectorBroadcastC = actualThoughtVectorBroadcastC;
+    }
 
     /* Decoder */
     auto beamSearchReorderHook = Constant({ 1, 1 }, 1.0f, device);
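The usePlaceholders path builds the decoder against placeholders and splices in the real broadcast thought vectors afterwards via ReplacePlaceholders (done below, once z is formed). The pattern in isolation (a sketch with hypothetical variables):

CNTK::FunctionPtr BuildWithDeferredInput(const CNTK::Variable& input, const CNTK::Variable& actual)
{
    auto deferred = CNTK::PlaceholderVariable();
    auto net = CNTK::Plus(input, deferred);                    // compose against the placeholder
    return net->ReplacePlaceholders({ { deferred, actual } }); // later, patch in the real variable
}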
@@ -116,6 +128,10 @@ void TrainSequenceToSequenceTranslator(const DeviceDescriptor& device, bool useS
     auto biasWeights = Parameter({ labelVocabDim }, 0.0f, device);
 
     auto z = Plus(Times(outputLayerProjWeights, Stabilize<float>(decoderOutput, device)), biasWeights, L"classifierOutput");
+
+    if (usePlaceholders)
+        z->ReplacePlaceholders({ { thoughtVectorBroadcastH, actualThoughtVectorBroadcastH }, { thoughtVectorBroadcastC, actualThoughtVectorBroadcastC } });
+
     auto ce = CrossEntropyWithSoftmax(z, labelSequence, L"lossFunction");
     auto errs = ClassificationError(z, labelSequence, L"classificationError");
 
@@ -218,8 +234,8 @@ void TrainSequenceToSequenceTranslator()
     fprintf(stderr, "\nTrainSequenceToSequenceTranslator..\n");
 
     // TODO: Also test with sparse input variables in the graph
-    TrainSequenceToSequenceTranslator(DeviceDescriptor::CPUDevice(), false, true, false, true, true);
+    TrainSequenceToSequenceTranslator(DeviceDescriptor::CPUDevice(), false, true, false, false, true, true);
 
     if (IsGPUAvailable())
-        TrainSequenceToSequenceTranslator(DeviceDescriptor::GPUDevice(0), false, false, true, false, false);
+        TrainSequenceToSequenceTranslator(DeviceDescriptor::GPUDevice(0), false, false, true, true, false, false);
 }
@@ -19,6 +19,8 @@
 %rename(gpu_device) CNTK::DeviceDescriptor::GPUDevice;
 %rename(cpu_device) CNTK::DeviceDescriptor::CPUDevice;
 %rename(times_transpose) CNTK::TransposeTimes;
+%rename(sequence_slice) CNTK::Sequence::Slice;
+%rename(sequence_reduce_sum) CNTK::Sequence::ReduceSum;
 
 %rename(momentum_as_time_constant_schedule) CNTK::MomentumAsTimeConstantSchedule;
 
@@ -63,6 +63,28 @@ def is_last(seq, name=''):
     seq = sanitize_input(seq, get_data_type(seq))
     return is_last(seq, name)
 
+@typemap
+def slice(seq, begin_index, end_index, name=''):
+    '''
+    Slice the input sequence.
+
+    Examples:
+        TBA
+    Args:
+        seq: sequence input tensor
+        begin_index (`int`): the index along sequence axis where the slicing starts
+        end_index (`int`): the index along sequence axis where the slicing ends
+        name (`str`, optional): the name of the Function instance in the network
+
+    See also:
+        Indexing in NumPy: http://docs.scipy.org/doc/numpy/reference/arrays.indexing.html
+
+    Returns:
+        :class:`cntk.ops.functions.Function`
+    '''
+    from cntk.cntk_py import sequence_slice
+    seq = sanitize_input(seq, get_data_type(seq))
+    return sequence_slice(seq, begin_index, end_index, name)
+
 @typemap
 def first(seq, name=''):
@@ -281,3 +303,21 @@ def broadcast_as(operand, broadcast_as_operand, name=''):
     broadcast_as_operand = sanitize_input(
         broadcast_as_operand, get_data_type(broadcast_as_operand))
     return broadcast_as(operand, broadcast_as_operand, name)
+
+@typemap
+def reduce_sum(seq, name=''):
+    '''
+    Computes the sum of the input sequence's elements across the sequence axis.
+
+    Examples:
+        TBA
+    Args:
+        seq: sequence input tensor
+        name (`str`, optional): the name of the Function instance in the network
+
+    Returns:
+        :class:`cntk.ops.functions.Function`
+    '''
+    from cntk.cntk_py import sequence_reduce_sum
+    seq = sanitize_input(seq, get_data_type(seq))
+    return sequence_reduce_sum(seq, name)
@@ -166,7 +166,7 @@ def test_op_slice_sequence(input_data, slice_params, expected_result, device_id,
                        dynamic_axes=[Axis.default_batch_axis(), t],
                        name='a')
 
-    result = C.slice(a, axis=t, begin_index=slice_params[
+    result = C.sequence.slice(a, begin_index=slice_params[
         0], end_index=slice_params[1])
 
     def grad_slice(x, beg_index, end_index):
@@ -11,7 +11,7 @@ from cntk import Trainer, Axis, save_model, load_model #, text_format_minibatch_
 from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs, INFINITELY_REPEAT, FULL_DATA_SWEEP
 from cntk.device import cpu, set_default_device
 from cntk.learner import momentum_sgd, momentum_as_time_constant_schedule
-from cntk.ops import input_variable, cross_entropy_with_softmax, classification_error, sequence, slice, past_value, future_value, element_select, alias, hardmax
+from cntk.ops import input_variable, cross_entropy_with_softmax, classification_error, sequence, past_value, future_value, element_select, alias, hardmax
 from cntk.ops.functions import CloneMethod
 
 abs_path = os.path.dirname(os.path.abspath(__file__))
@@ -94,7 +94,7 @@ def sequence_to_sequence_translator(debug_output=False, run_test=False):
     input_sequence = raw_input
 
     # Drop the sentence start token from the label, for decoder training
-    label_sequence = slice(raw_labels, label_seq_axis, 1, 0) # <s> A B C </s> --> A B C </s>
+    label_sequence = sequence.slice(raw_labels, 1, 0) # <s> A B C </s> --> A B C </s>
     label_sentence_start = sequence.first(raw_labels) # <s>
 
     is_first_label = sequence.is_first(label_sequence) # <s> 0 0 0 ...
@@ -239,7 +239,7 @@ def sequence_to_sequence_translator(debug_output=False, run_test=False):
     z = load_model("seq2seq.dnn")
 
     label_seq_axis = Axis('labelAxis')
-    label_sequence = slice(find_arg_by_name('raw_labels',z), label_seq_axis, 1, 0)
+    label_sequence = sequence.slice(find_arg_by_name('raw_labels',z), 1, 0)
     ce = cross_entropy_with_softmax(z, label_sequence)
     errs = classification_error(z, label_sequence)
     trainer = Trainer(z, ce, errs, [momentum_sgd(
@@ -10,11 +10,11 @@ from cntk import Trainer, Axis #, text_format_minibatch_source, StreamConfigurat
 from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs, INFINITELY_REPEAT, FULL_DATA_SWEEP
 from cntk.device import cpu, set_default_device
 from cntk.learner import sgd
-from cntk.ops import input_variable, cross_entropy_with_softmax, classification_error
+from cntk.ops import input_variable, cross_entropy_with_softmax, classification_error, sequence
 
 abs_path = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(os.path.join(abs_path, "..", ".."))
-from examples.common.nn import LSTMP_component_with_self_stabilization, embedding, linear_layer, select_last, print_training_progress
+from examples.common.nn import LSTMP_component_with_self_stabilization, embedding, linear_layer, print_training_progress
 
 # Creates the reader
 def create_reader(path, is_training, input_dim, label_dim):
@@ -28,7 +28,7 @@ def LSTM_sequence_classifer_net(input, num_output_classes, embedding_dim, LSTM_d
     embedding_function = embedding(input, embedding_dim)
     LSTM_function = LSTMP_component_with_self_stabilization(
         embedding_function.output, LSTM_dim, cell_dim)[0]
-    thought_vector = select_last(LSTM_function)
+    thought_vector = sequence.last(LSTM_function)
 
     return linear_layer(thought_vector, num_output_classes)
 