diff --git a/Source/CNTKv2LibraryDll/API/CNTKLibrary.h b/Source/CNTKv2LibraryDll/API/CNTKLibrary.h
index 88071578f..23b35aa88 100644
--- a/Source/CNTKv2LibraryDll/API/CNTKLibrary.h
+++ b/Source/CNTKv2LibraryDll/API/CNTKLibrary.h
@@ -2399,6 +2399,11 @@ namespace CNTK
         ///
         CNTK_API static FunctionPtr LoadModel(DataType dataType, const std::wstring& modelFile, const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());
 
+        ///
+        /// Prints the entire graph underlying this function to stderr
+        ///
+        CNTK_API void PrintGraph() const;
+
     private:
         template
@@ -2899,6 +2904,13 @@ namespace CNTK
         CNTK_API FunctionPtr IsFirst(const Variable& operand, const std::wstring& name = L"");
         CNTK_API FunctionPtr IsLast(const Variable& operand, const std::wstring& name = L"");
 
+        CNTK_API FunctionPtr Slice(const Variable& operand, int beginIndex, int endIndex, const std::wstring& name = L"");
+
+        ///
+        /// Create an instance of the CNTK built-in sum reduction operation on the specified tensor input operand, along the operand's lone dynamic sequence axis
+        ///
+        CNTK_API FunctionPtr ReduceSum(const Variable& operand, const std::wstring& name = L"");
+
         CNTK_API FunctionPtr First(const Variable& operand, const std::wstring& name = L"");
         CNTK_API FunctionPtr Last(const Variable& operand, const std::wstring& name = L"");

diff --git a/Source/CNTKv2LibraryDll/API/CNTKLibraryInternals.h b/Source/CNTKv2LibraryDll/API/CNTKLibraryInternals.h
index 21dc46dc4..ecba971a1 100644
--- a/Source/CNTKv2LibraryDll/API/CNTKLibraryInternals.h
+++ b/Source/CNTKv2LibraryDll/API/CNTKLibraryInternals.h
@@ -206,9 +206,11 @@ namespace CNTK
         CNTK_API FunctionPtr GatherPacked(const Variable& operand, const Variable& packedIndex, const std::wstring& name = L"");
         CNTK_API FunctionPtr ScatterPacked(const Variable& operand, const Variable& packedIndex, const Variable& condition, const std::wstring& name = L"");
         CNTK_API FunctionPtr ZeroesWithDynamicAxesLike(const Variable& operand);
-        CNTK_API FunctionPtr Where(const Variable& condition, const std::vector<Axis>& newDynamicAxes, const std::wstring& name = L"");
-        CNTK_API FunctionPtr Gather(const Variable& operand, const Variable& condition, const std::vector<Axis>& newDynamicAxes, const std::wstring& name = L"");
-        CNTK_API FunctionPtr Scatter(const Variable& operand, const Variable& condition, const std::vector<Axis>& newDynamicAxes, const std::wstring& name = L"");
+        CNTK_API FunctionPtr Where(const Variable& condition, const std::pair<size_t, int>& newDerivedSequenceAxisScalingAndAdditiveFactor, const std::wstring& name = L"");
+        CNTK_API FunctionPtr Gather(const Variable& operand, const Variable& condition, const std::wstring& name = L"");
+        CNTK_API FunctionPtr Gather(const Variable& operand, const Variable& condition, const std::pair<size_t, int>& newDerivedSequenceAxisScalingAndAdditiveFactor, const std::wstring& name = L"");
+        CNTK_API FunctionPtr Scatter(const Variable& operand, const Variable& condition, const std::wstring& name = L"");
+        CNTK_API FunctionPtr Scatter(const Variable& operand, const Variable& condition, const std::pair<size_t, int>& newDerivedSequenceAxisScalingAndAdditiveFactor, const std::wstring& name = L"");
         CNTK_API FunctionPtr Slice(const Variable& operand, const Axis& axis, int beginIndex, int endIndex, const std::wstring& name = L"");
         CNTK_API FunctionPtr ReduceElements(const Variable& operand, const std::wstring& reductionOpName, const Axis& axis, const std::wstring& name = L"");
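A note on the new Internal signatures above: instead of passing explicit new dynamic axes, callers now describe the derived sequence axis with a {scaling, additive} pair of type std::pair<size_t, int>. The Function.cpp hunk below feeds this pair to GetDerivedDynamicAxes for the Where node. A minimal sketch of the implied length arithmetic (the helper name is hypothetical, for illustration only):

    // Sketch only: how the {scaling, additive} factors used by
    // Internal::Where/Gather/Scatter describe the derived sequence axis length.
    #include <cstddef>

    size_t DerivedSequenceAxisLength(size_t scalingFactor, int additiveFactor, size_t inputLength)
    {
        // Sequence::Slice(x, 0, 1) passes {0, 1}:  constant length 1
        // Sequence::Slice(x, 1, 0) passes {1, -1}: inputLength - 1
        return scalingFactor * inputLength + additiveFactor;
    }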
diff --git a/Source/CNTKv2LibraryDll/Function.cpp b/Source/CNTKv2LibraryDll/Function.cpp
index 47c8465d7..0a102212f 100644
--- a/Source/CNTKv2LibraryDll/Function.cpp
+++ b/Source/CNTKv2LibraryDll/Function.cpp
@@ -544,6 +544,12 @@ namespace CNTK
         return CompositeFunction::Deserialize(modelDictionary, device);
     }
 
+    void Function::PrintGraph() const
+    {
+        CompositeFunction::Traverse(RootFunction(), [](const FunctionPtr& function) {
+        });
+    }
+
     // Names for the reduction operations as used by the CNTK ReduceElementsNode
     /*static*/ const std::wstring PrimitiveFunction::InternalSumReductionOpName = L"Sum";
     /*static*/ const std::wstring PrimitiveFunction::InternalLogSumReductionOpName = L"LogSum";
@@ -580,6 +586,8 @@ namespace CNTK
     /*static*/ const std::wstring PrimitiveFunction::AttributeNameEpsilon = L"epsilon";
     /*static*/ const std::wstring PrimitiveFunction::AttributeNameUseCuDNNEngine = L"useCuDNNEngine";
     /*static*/ const std::wstring PrimitiveFunction::AttributeNameNewDynamicAxes = L"newDynamicAxes";
+    /*static*/ const std::wstring PrimitiveFunction::AttributeNameNewSequenceAxisLengthScalingFactor = L"newSequenceAxisLengthScalingFactor";
+    /*static*/ const std::wstring PrimitiveFunction::AttributeNameNewSequenceAxisLengthAdditiveFactor = L"newSequenceAxisLengthAdditiveFactor";
     /*static*/ const std::wstring PrimitiveFunction::AttributeNameBeginIndex = L"beginIndex";
     /*static*/ const std::wstring PrimitiveFunction::AttributeNameEndIndex = L"endIndex";
     /*static*/ const std::wstring PrimitiveFunction::AttributeNameReductionOpName = L"reductionOpName";
@@ -631,7 +639,36 @@ namespace CNTK
            if ((op == PrimitiveOpType::SumAll) || (op == PrimitiveOpType::SquaredError) || (op == PrimitiveOpType::CrossEntropyWithSoftmax) || (op == PrimitiveOpType::ClassificationError))
                outputDynamicAxes = std::vector<Axis>({});
            else if (op == PrimitiveOpType::Where)
-               outputDynamicAxes = AsVector<Axis>(functionConfig[PrimitiveFunction::AttributeNameNewDynamicAxes].Value<std::vector<DictionaryValue>>());
+           {
+               if (functionConfig.Contains(PrimitiveFunction::AttributeNameNewDynamicAxes))
+                   outputDynamicAxes = AsVector<Axis>(functionConfig[PrimitiveFunction::AttributeNameNewDynamicAxes].Value<std::vector<DictionaryValue>>());
+               else
+               {
+                   if (inputs[0].DynamicAxes() == Axis::UnknownDynamicAxes())
+                       outputDynamicAxes = Axis::UnknownDynamicAxes();
+                   else
+                   {
+                       if (functionConfig.Contains(PrimitiveFunction::AttributeNameNewSequenceAxisLengthScalingFactor) &&
+                           functionConfig.Contains(PrimitiveFunction::AttributeNameNewSequenceAxisLengthAdditiveFactor))
+                       {
+                           size_t newSequenceAxisLengthScalingFactor = functionConfig[PrimitiveFunction::AttributeNameNewSequenceAxisLengthScalingFactor].Value<size_t>();
+                           int newSequenceAxisLengthAdditiveFactor = functionConfig[PrimitiveFunction::AttributeNameNewSequenceAxisLengthAdditiveFactor].Value<int>();
+
+                           auto derivedDynamicAxes = GetDerivedDynamicAxes(inputs[0].DynamicAxes()[0], newSequenceAxisLengthScalingFactor, newSequenceAxisLengthAdditiveFactor);
+                           std::copy(derivedDynamicAxes.begin(), derivedDynamicAxes.end(), std::back_inserter(outputDynamicAxes));
+                       }
+                       else
+                       {
+                           outputDynamicAxes.push_back(Axis::NewUniqueDynamicAxis(L"whereNodeDynamicAxis"));
+                       }
+
+                       for (size_t i = 1; i < inputs[0].DynamicAxes().size(); ++i)
+                           outputDynamicAxes.push_back(inputs[0].DynamicAxes()[i]);
+
+                       functionConfig[PrimitiveFunction::AttributeNameNewDynamicAxes] = AsDictionaryValueVector(outputDynamicAxes);
+                   }
+               }
+           }
            else if (op == PrimitiveOpType::ScatterPacked)
                outputDynamicAxes = inputs[2].DynamicAxes();
            else if ((op == PrimitiveOpType::PackedIndex) || (op == PrimitiveOpType::GatherPacked))
@@ -1098,7 +1135,7 @@ namespace CNTK
        std::vector<FunctionPtr> topoSortedPrimitiveFunctions;
        std::vector<Variable> inputs;
        std::unordered_set<std::wstring> inputUids;
-       Traverse([&visitedFunctions, &inputs, &topoSortedPrimitiveFunctions, &inputUids](const FunctionPtr& function) {
+       Traverse(RootFunction(), [&visitedFunctions, &inputs, &topoSortedPrimitiveFunctions, &inputUids](const FunctionPtr& function) {
            std::vector<Variable> functionInputs = function->Inputs();
            for (const auto& input : functionInputs)
            {
@@ -2585,7 +2622,7 @@ namespace CNTK
     FunctionPtr Round(const Variable& operand, const std::wstring& name)
     {
-        return Floor(Plus(operand, Constant::Scalar(operand.GetDataType(), 0.5)), name);
+        return Floor(Plus(operand, Constant::Scalar(0.5f)), name);
     }
 
     FunctionPtr Floor(const Variable& operand, const std::wstring& name)
@@ -2633,11 +2670,9 @@ namespace CNTK
         return TransposeAxes(operand, Axis(0), Axis(1), name);
     }
 
+
     FunctionPtr Slice(const Variable& operand, const Axis& axis, int beginIndex, int endIndex, const std::wstring& name)
     {
-        if (axis == Axis::DefaultBatchAxis())
-            LogicError("Slice is currently unsupported along the batch axis");
-
         if (axis.IsStaticAxis())
         {
             if ((endIndex - beginIndex) <= 0)
@@ -2646,46 +2681,10 @@ namespace CNTK
             return Internal::Slice(operand, axis, beginIndex, endIndex, name);
         }
 
-        if ((beginIndex == 0) && (endIndex == 0))
-            return operand;
+        if (axis == Axis::DefaultBatchAxis())
+            LogicError("Slice is currently unsupported along the batch axis");
 
-        auto operandAxes = operand.DynamicAxes();
-        auto findAxis = std::find(operandAxes.begin(), operandAxes.end(), axis);
-        if (findAxis == operandAxes.end())
-            InvalidArgument("The specified dynamic axis named %S does not match any of the dynamic axes of the operand", axis.Name().c_str());
-
-        auto beginFlagsLambda = [beginIndex, operand]() {
-            return (beginIndex > 0) ? Minus(Constant::Scalar(operand.GetDataType(), 1.0), Internal::IsWithin(operand, beginIndex)) : Internal::IsWithin(operand, beginIndex);
-        };
-
-        auto endFlagsLambda = [endIndex, operand]() {
-            return (endIndex > 0) ? Internal::IsWithin(operand, endIndex) : Minus(Constant::Scalar(operand.GetDataType(), 1.0), Internal::IsWithin(operand, endIndex));
-        };
-
-        FunctionPtr flags;
-        if (beginIndex == 0)
-            flags = endFlagsLambda();
-        else if (endIndex == 0)
-            flags = beginFlagsLambda();
-        else
-            flags = ElementTimes(beginFlagsLambda(), endFlagsLambda());
-
-        // Since we are slicing along a dynamic axis, the output variable's dynamic axes will be different than the operand
-        std::vector<Axis> newDynamicAxes;
-        for (auto operandAxis : operandAxes)
-        {
-            if (operandAxis == axis)
-            {
-                int sliceLength = (endIndex - beginIndex);
-                size_t multiplicativeFactor = (sliceLength > 0) ? 0 : 1;
-                auto derivedDynamicAxes = GetDerivedDynamicAxes(operandAxis, multiplicativeFactor, sliceLength);
-                std::copy(derivedDynamicAxes.begin(), derivedDynamicAxes.end(), std::back_inserter(newDynamicAxes));
-            }
-            else
-                newDynamicAxes.push_back(operandAxis);
-        }
-
-        return Internal::Gather(operand, flags, newDynamicAxes, name);
+        LogicError("CNTK::Slice: Invalid axis argument provided. To slice a sequence along its ordered dynamic axis use Sequence::Slice.");
    }
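The hunk above redirects dynamic-axis slicing to the new Sequence::Slice overload (added later in this file), while static-axis slicing keeps the Axis-based form. A hedged usage sketch with illustrative names only, not part of the patch:

    #include "CNTKLibrary.h"
    using namespace CNTK;

    void SliceApiSketch()
    {
        auto x = InputVariable({ 3 }, DataType::Float, L"x"); // default sequence and batch dynamic axes
        auto firstTwoRows  = Slice(x, Axis(0), 0, 2);         // static axis 0: rows [0, 2)
        auto dropFirstStep = Sequence::Slice(x, 1, 0);        // sequence axis: drop the first element
        auto lastStep      = Sequence::Slice(x, -1, 0);       // same as Sequence::Last(x)
    }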
 
     FunctionPtr RandomSample(const Variable& operand, size_t numSamples, bool allowDuplicates, const std::wstring& name)
@@ -2721,6 +2720,7 @@ namespace CNTK
         return UnaryOp(PrimitiveOpType::Reshape, operand, std::move(additionalProperties), name);
     }
 
+
     FunctionPtr BinaryOp(PrimitiveOpType op, const Variable& leftOperand, const Variable& rightOperand, Dictionary&& opConfig, const std::wstring& name)
     {
         std::vector<Variable> operands = { leftOperand, rightOperand };
@@ -2815,14 +2815,14 @@ namespace CNTK
         if (topN == 1)
         {
             if (axis == Axis(0))
-                return Minus(Constant::Scalar(prediction.GetDataType(), 1.0), TransposeTimes(labels, Hardmax(prediction)), name);
+                return Minus(Constant::Scalar(1.0f), TransposeTimes(labels, Hardmax(prediction)), name);
             else
             {
                 auto axMax = ReduceMax(prediction, axis);
                 auto pred = Equal(prediction, axMax);
                 auto wrongPred = NotEqual(labels, pred);
                 auto axErr = ReduceSum(wrongPred, axis);
-                auto capErr = GreaterEqual(axErr, Constant::Scalar(prediction.GetDataType(), 1.0));
+                auto capErr = GreaterEqual(axErr, Constant::Scalar(1.0f));
                 return ReduceMean(capErr, Axis::AllStaticAxes(), name);
             }
         }
@@ -2831,7 +2831,7 @@ namespace CNTK
             if (axis != Axis(0))
                 LogicError("ClassificationError along a specific axis does not support topN!");
 
-            std::vector<Variable> operands = { prediction, labels, Constant::Scalar(prediction.GetDataType(), (double)topN) };
+            std::vector<Variable> operands = { prediction, labels, Constant::Scalar((float)topN) };
             return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::ClassificationError, operands, Dictionary(), name), name);
         }
     }
@@ -3011,75 +3011,113 @@
     {
         // TODO: This is a temporary and expensive hack until we have a real alias implementation
         // that does not waste memory and compute cycles
-        return Plus(operand, Constant::Scalar(operand.GetDataType(), 0), name);
+        return Plus(operand, Constant::Scalar(0.0f), name);
     }
 
     namespace Sequence
     {
         void VerifyIsSequence(const Variable& operand)
         {
-            // The operand must have at least one dynamic axis and its first dynamic axis must be ordered
-            if (operand.DynamicAxes().empty() || !operand.DynamicAxes()[0].IsOrdered())
+            // The operand must have at least one dynamic axis
+            if (operand.DynamicAxes().empty())
                 InvalidArgument("A sequence function can only be applied on operands with at least one dynamic axis and whose first dynamic axis is ordered");
         }
 
         FunctionPtr IsFirst(const Variable& operand, const std::wstring& name)
         {
-            VerifyIsSequence(operand);
             return Internal::IsWithin(operand, 1, name);
         }
 
         FunctionPtr IsLast(const Variable& operand, const std::wstring& name)
         {
-            VerifyIsSequence(operand);
             return Internal::IsWithin(operand, -1, name);
         }
 
+        FunctionPtr Slice(const Variable& operand, int beginIndex, int endIndex, const std::wstring& name)
+        {
+            VerifyIsSequence(operand);
+
+            if ((beginIndex == 0) && (endIndex == 0))
+                return operand;
+
+            auto beginFlagsLambda = [beginIndex, operand]() {
+                return (beginIndex > 0) ? Minus(Constant::Scalar(1.0f), Internal::IsWithin(operand, beginIndex)) : Internal::IsWithin(operand, beginIndex);
+            };
+
+            auto endFlagsLambda = [endIndex, operand]() {
+                return (endIndex > 0) ? Internal::IsWithin(operand, endIndex) : Minus(Constant::Scalar(1.0f), Internal::IsWithin(operand, endIndex));
+            };
+
+            FunctionPtr flags;
+            if (beginIndex == 0)
+                flags = endFlagsLambda();
+            else if (endIndex == 0)
+                flags = beginFlagsLambda();
+            else
+                flags = ElementTimes(beginFlagsLambda(), endFlagsLambda());
+
+            int sliceLength = (endIndex - beginIndex);
+            size_t multiplicativeFactor = (sliceLength > 0) ? 0 : 1;
+
+            return Internal::Gather(operand, flags, { multiplicativeFactor, sliceLength }, name);
+        }
+
         FunctionPtr First(const Variable& operand, const std::wstring& name)
         {
-            VerifyIsSequence(operand);
-            return Slice(operand, operand.DynamicAxes()[0], 0, 1, name);
+            return Sequence::Slice(operand, 0, 1, name);
         }
 
         FunctionPtr Last(const Variable& operand, const std::wstring& name)
         {
-            VerifyIsSequence(operand);
-            return Slice(operand, operand.DynamicAxes()[0], -1, 0, name);
-        }
-
-        std::vector<Axis> WhereOpDynamicAxes(const Variable& operand)
-        {
-            VerifyIsSequence(operand);
-
-            std::vector<Axis> newDynamicAxes = { Axis::NewUniqueDynamicAxis(L"whereNodeDynamicAxis") };
-            for (size_t i = 1; i < operand.DynamicAxes().size(); ++i)
-                newDynamicAxes.push_back(operand.DynamicAxes()[i]);
-
-            return newDynamicAxes;
+            return Sequence::Slice(operand, -1, 0, name);
         }
 
         FunctionPtr Where(const Variable& condition, const std::wstring& name)
         {
-            return Internal::Where(condition, WhereOpDynamicAxes(condition), name);
+            return UnaryOp(PrimitiveOpType::Where, condition, Dictionary(), name);
         }
 
         FunctionPtr Gather(const Variable& operand, const Variable& condition, const std::wstring& name)
         {
-            return Internal::Gather(operand, condition, WhereOpDynamicAxes(condition), name);
+            return Internal::Gather(operand, condition, name);
         }
 
         FunctionPtr Scatter(const Variable& operand, const Variable& condition, const std::wstring& name)
         {
-            return Internal::Scatter(operand, condition, WhereOpDynamicAxes(condition), name);
+            return Internal::Scatter(operand, condition, name);
         }
 
         FunctionPtr BroadcastAs(const Variable& operand, const Variable& broadcastAs, const std::wstring& name)
         {
-            auto dataPadded = Internal::Scatter(operand, Sequence::IsFirst(broadcastAs), operand.DynamicAxes());
+            auto dataPadded = Internal::Scatter(operand, Sequence::IsFirst(broadcastAs), std::make_pair(0, 1));
             auto placeHolderOutput = PlaceholderVariable(operand.Shape(), broadcastAs.DynamicAxes());
             auto output = ElementSelect(Sequence::IsFirst(broadcastAs), dataPadded, PastValue(placeHolderOutput), name);
             return output->ReplacePlaceholders({ { placeHolderOutput, output } });
         }
+
+        FunctionPtr ReduceElements(const Variable& operand, const std::wstring& reductionOpName, const std::wstring& name)
+        {
+            using namespace std::placeholders;
+
+            std::function<FunctionPtr(const Variable& leftOperand, const Variable& rightOperand)> reductionFunctor;
+            if (reductionOpName == PrimitiveFunction::InternalSumReductionOpName)
+                reductionFunctor = std::bind(Plus, _1, _2, L"");
+            else
+                LogicError("%S reduction along dynamic axis is currently unsupported", reductionOpName.c_str());
+
+            // We are reducing over a dynamic axis which is currently implemented using recurrence
+            auto cumulativeSumFunctionPlaceholder = PlaceholderVariable(operand.Shape());
+            auto prevAccumulatedValuesFunction = PastValue(cumulativeSumFunctionPlaceholder);
+            auto cumulativeSumFunction = reductionFunctor(prevAccumulatedValuesFunction, operand);
+            cumulativeSumFunction->ReplacePlaceholders({ { cumulativeSumFunctionPlaceholder, cumulativeSumFunction } });
+
+            return Sequence::Slice(cumulativeSumFunction, -1, 0, name);
+        }
+
+        FunctionPtr ReduceSum(const Variable& operand, const std::wstring& name)
+        {
+            return ReduceElements(operand, PrimitiveFunction::InternalSumReductionOpName, name);
+        }
     }
 
     namespace Internal
@@ -3092,9 +3130,9 @@
            InvalidArgument("CNTK::Sequence::IsWithin: The offset must be positive");
 
        if (offset > 0)
-           return PastValue(Internal::ZeroesWithDynamicAxesLike(operand), Constant::Scalar(operand.GetDataType(), 1.0), offset, name);
+           return PastValue(Internal::ZeroesWithDynamicAxesLike(operand), Constant::Scalar(1.0f), offset, name);
        else
-           return FutureValue(Internal::ZeroesWithDynamicAxesLike(operand), Constant::Scalar(operand.GetDataType(), 1.0), -offset, name);
+           return FutureValue(Internal::ZeroesWithDynamicAxesLike(operand), Constant::Scalar(1.0f), -offset, name);
        }
 
        FunctionPtr PackedIndex(const Variable& operand, const Variable& index, const std::wstring& name)
@@ -3131,21 +3169,32 @@ namespace CNTK
            }
        }
 
-       FunctionPtr Where(const Variable& condition, const std::vector<Axis>& newDynamicAxes, const std::wstring& name)
+       FunctionPtr Where(const Variable& condition, const std::pair<size_t, int>& newDerivedSequenceAxisScalingAndAdditiveFactor, const std::wstring& name)
        {
            auto additionalProperties = Dictionary();
-           additionalProperties[PrimitiveFunction::AttributeNameNewDynamicAxes] = AsDictionaryValueVector(newDynamicAxes);
+           additionalProperties[PrimitiveFunction::AttributeNameNewSequenceAxisLengthScalingFactor] = newDerivedSequenceAxisScalingAndAdditiveFactor.first;
+           additionalProperties[PrimitiveFunction::AttributeNameNewSequenceAxisLengthAdditiveFactor] = newDerivedSequenceAxisScalingAndAdditiveFactor.second;
            return UnaryOp(PrimitiveOpType::Where, condition, std::move(additionalProperties), name);
        }
 
-       FunctionPtr Gather(const Variable& operand, const Variable& condition, const std::vector<Axis>& newDynamicAxes, const std::wstring& name)
+       FunctionPtr Gather(const Variable& operand, const Variable& condition, const std::wstring& name)
        {
-           return Internal::GatherPacked(operand, Internal::PackedIndex(/*layout of*/ operand, Where(condition, newDynamicAxes)), name);
+           return Internal::GatherPacked(operand, Internal::PackedIndex(/*layout of*/ operand, Sequence::Where(condition)), name);
        }
 
-       FunctionPtr Scatter(const Variable& operand, const Variable& condition, const std::vector<Axis>& whereNodeDynamicAxes, const std::wstring& name)
+       FunctionPtr Gather(const Variable& operand, const Variable& condition, const std::pair<size_t, int>& newDerivedSequenceAxisScalingAndAdditiveFactor, const std::wstring& name)
        {
-           return Internal::ScatterPacked(operand, Internal::PackedIndex(/*layout of*/ condition, Where(condition, whereNodeDynamicAxes)), /*layout of*/ condition, name);
+           return Internal::GatherPacked(operand, Internal::PackedIndex(/*layout of*/ operand, Where(condition, newDerivedSequenceAxisScalingAndAdditiveFactor)), name);
+       }
+
+       FunctionPtr Scatter(const Variable& operand, const Variable& condition, const std::wstring& name)
+       {
+           return Internal::ScatterPacked(operand, Internal::PackedIndex(/*layout of*/ condition, Sequence::Where(condition)), /*layout of*/ condition, name);
+       }
+
+       FunctionPtr Scatter(const Variable& operand, const Variable& condition, const std::pair<size_t, int>& newDerivedSequenceAxisScalingAndAdditiveFactor, const std::wstring& name)
+       {
+           return Internal::ScatterPacked(operand, Internal::PackedIndex(/*layout of*/ condition, Where(condition, newDerivedSequenceAxisScalingAndAdditiveFactor)), /*layout of*/ condition, name);
        }
 
        FunctionPtr Slice(const Variable& operand, const Axis& axis, int beginIndex, int endIndex, const std::wstring& name)
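For orientation before the next hunk: the Sequence::ReduceElements added above implements sum reduction over the sequence axis as a recurrence and then keeps the final step. A sketch of the equivalent hand-built graph, with hypothetical variable names:

    #include "CNTKLibrary.h"
    using namespace CNTK;

    FunctionPtr SequenceSumSketch()
    {
        auto x = InputVariable({ 10 }, DataType::Float, L"x");
        auto accumPlaceholder = PlaceholderVariable(x.Shape());
        auto accum = Plus(PastValue(accumPlaceholder), x); // accum_t = x_t + accum_(t-1)
        accum->ReplacePlaceholders({ { accumPlaceholder, accum } });
        return Sequence::Slice(accum, -1, 0); // keep only the final accumulated step
    }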
@@ -3160,8 +3209,6 @@ namespace CNTK
 
        FunctionPtr ReduceElements(const Variable& operand, const std::wstring& reductionOpName, const Axis& axis, const std::wstring& name)
        {
-           using namespace std::placeholders;
-
            if (axis.IsStaticAxis() || (axis == Axis::AllStaticAxes()))
            {
                auto additionalProperties = Dictionary();
@@ -3173,20 +3220,7 @@ namespace CNTK
            if (axis == Axis::DefaultBatchAxis())
                LogicError("Reduction is currently unsupported along the batch axis");
 
-           if (reductionOpName != PrimitiveFunction::InternalSumReductionOpName)
-               LogicError("%S reduction along dynamic axis is currently unsupported", reductionOpName.c_str());
-
-           std::function<FunctionPtr(const Variable& leftOperand, const Variable& rightOperand)> reductionFunctor;
-           if (reductionOpName == PrimitiveFunction::InternalSumReductionOpName)
-               reductionFunctor = std::bind(Plus, _1, _2, L"");
-
-           // We are reducing over a dynamic axis which is currently implemented using recurrence
-           auto cumulativeSumFunctionPlaceholder = PlaceholderVariable(operand.Shape());
-           auto prevAccumulatedValuesFunction = PastValue(cumulativeSumFunctionPlaceholder);
-           auto cumulativeSumFunction = reductionFunctor(prevAccumulatedValuesFunction, operand);
-           cumulativeSumFunction->ReplacePlaceholders({ { cumulativeSumFunctionPlaceholder, cumulativeSumFunction } });
-
-           return CNTK::Slice(cumulativeSumFunction, axis, -1, 0, name);
+           LogicError("CNTK::ReduceElements: Invalid axis argument provided. To reduce a sequence along its ordered dynamic axis use Sequence::ReduceElements.");
        }
    }
 }

diff --git a/Source/CNTKv2LibraryDll/Function.h b/Source/CNTKv2LibraryDll/Function.h
index fc5fe5cd8..2702a6131 100644
--- a/Source/CNTKv2LibraryDll/Function.h
+++ b/Source/CNTKv2LibraryDll/Function.h
@@ -187,6 +187,8 @@ namespace CNTK
         static const std::wstring AttributeNameEpsilon;
         static const std::wstring AttributeNameUseCuDNNEngine;
         static const std::wstring AttributeNameNewDynamicAxes;
+        static const std::wstring AttributeNameNewSequenceAxisLengthScalingFactor;
+        static const std::wstring AttributeNameNewSequenceAxisLengthAdditiveFactor;
         static const std::wstring AttributeNameBeginIndex;
         static const std::wstring AttributeNameEndIndex;
         static const std::wstring AttributeNameReductionOpName;
@@ -699,22 +701,11 @@ namespace CNTK
             return CompositeFunctionOpName;
         }
 
-    private:
-        virtual void ReplacePlaceholdersInPlace(const std::unordered_map<Variable, Variable>& placeholderReplacements,
-                                                std::unordered_set<const Function*>& visitedFunctions,
-                                                std::unordered_set<Variable>& replacedPlaceholders) override;
-
-        CompositeFunction(const FunctionPtr& rootFunction, std::unordered_set<FunctionPtr>&& allPrimitiveFunctions, const std::wstring& name, const std::wstring& uid = Internal::GenerateUid(L"CompositeFunction"))
-            : Function({}, rootFunction->Outputs(), Dictionary(), rootFunction, name, uid),
-              m_allPrimitiveFunctions(std::move(allPrimitiveFunctions)), m_networkMatricesAllocated(false)
-        {}
-
         template <typename FunctionType>
-        void Traverse(const FunctionType& functor) const
+        static void Traverse(const FunctionPtr& rootFunction, const FunctionType& functor)
         {
-            const auto& root = RootFunction();
             std::unordered_set<FunctionPtr> visitedFunctions;
-            Traverse(root, visitedFunctions, functor);
+            Traverse(rootFunction, visitedFunctions, functor);
         }
 
         // Recursively traverses the Function graph underlying the 'rootFunction' invoking the provided functor for all visited nodes in the graph.
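The Traverse refactoring above (instance method to static, taking the root explicitly) is what lets Function::PrintGraph in Function.cpp walk the graph without a CompositeFunction instance; note the printing functor in that hunk is empty as displayed. A hedged caller-side sketch, assuming the internal Function.h header is available and Traverse is accessible from the call site:

    #include "Function.h" // internal CNTK v2 library header
    #include <cstdio>
    using namespace CNTK;

    void PrintGraphSketch(const FunctionPtr& root)
    {
        CompositeFunction::Traverse(root, [](const FunctionPtr& function) {
            fprintf(stderr, "%S\n", function->Name().c_str()); // illustrative body
        });
    }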
@@ -735,6 +726,16 @@ namespace CNTK
             }
         }
 
+    private:
+        virtual void ReplacePlaceholdersInPlace(const std::unordered_map<Variable, Variable>& placeholderReplacements,
+                                                std::unordered_set<const Function*>& visitedFunctions,
+                                                std::unordered_set<Variable>& replacedPlaceholders) override;
+
+        CompositeFunction(const FunctionPtr& rootFunction, std::unordered_set<FunctionPtr>&& allPrimitiveFunctions, const std::wstring& name, const std::wstring& uid = Internal::GenerateUid(L"CompositeFunction"))
+            : Function({}, rootFunction->Outputs(), Dictionary(), rootFunction, name, uid),
+              m_allPrimitiveFunctions(std::move(allPrimitiveFunctions)), m_networkMatricesAllocated(false)
+        {}
+
         std::vector<Variable> DetermineInputs() const
         {
             const auto& root = RootFunction();

diff --git a/Tests/UnitTests/V2LibraryTests/Common.h b/Tests/UnitTests/V2LibraryTests/Common.h
index 54009be18..f8cc584b8 100644
--- a/Tests/UnitTests/V2LibraryTests/Common.h
+++ b/Tests/UnitTests/V2LibraryTests/Common.h
@@ -200,7 +200,6 @@ inline CNTK::FunctionPtr Stabilize(const CNTK::Variable& x, const CNTK::DeviceDe
 template <typename ElementType>
 std::pair<CNTK::FunctionPtr, CNTK::FunctionPtr> LSTMPCellWithSelfStabilization(CNTK::Variable input, CNTK::Variable prevOutput, CNTK::Variable prevCellState, const CNTK::DeviceDescriptor& device)
 {
-    size_t inputDim = input.Shape()[0];
     size_t outputDim = prevOutput.Shape()[0];
     size_t cellDim = prevCellState.Shape()[0];
 
@@ -209,8 +208,8 @@ std::pair<CNTK::FunctionPtr, CNTK::FunctionPtr> LSTMPCellWithSelfStabilization(C
     };
 
     unsigned long seed = 1;
-    auto createProjectionParam = [device, &seed](size_t outputDim, size_t inputDim) {
-        return CNTK::Parameter({ outputDim, inputDim }, CNTK::AsDataType<ElementType>(), CNTK::GlorotUniformInitializer(1, 0, 1, seed++), device);
+    auto createProjectionParam = [device, &seed](size_t outputDim) {
+        return CNTK::Parameter({ outputDim, CNTK::NDShape::InferredDimension }, CNTK::AsDataType<ElementType>(), CNTK::GlorotUniformInitializer(1, 0, 1, seed++), device);
     };
 
     auto createDiagWeightParam = [device, &seed](size_t dim) {
@@ -220,26 +219,26 @@ std::pair<CNTK::FunctionPtr, CNTK::FunctionPtr> LSTMPCellWithSelfStabilization(C
     auto stabilizedPrevOutput = Stabilize(prevOutput, device);
     auto stabilizedPrevCellState = Stabilize(prevCellState, device);
 
-    auto projectInput = [input, cellDim, inputDim, createBiasParam, createProjectionParam]() {
-        return createBiasParam(cellDim) + CNTK::Times(createProjectionParam(cellDim, inputDim), input);
+    auto projectInput = [input, cellDim, createBiasParam, createProjectionParam]() {
+        return createBiasParam(cellDim) + CNTK::Times(createProjectionParam(cellDim), input);
     };
 
     // Input gate
-    auto it = CNTK::Sigmoid(projectInput() + CNTK::Times(createProjectionParam(cellDim, outputDim), stabilizedPrevOutput) + CNTK::ElementTimes(createDiagWeightParam(cellDim), stabilizedPrevCellState));
-    auto bit = CNTK::ElementTimes(it, CNTK::Tanh(projectInput() + CNTK::Times(createProjectionParam(cellDim, outputDim), stabilizedPrevOutput)));
+    auto it = CNTK::Sigmoid(projectInput() + CNTK::Times(createProjectionParam(cellDim), stabilizedPrevOutput) + CNTK::ElementTimes(createDiagWeightParam(cellDim), stabilizedPrevCellState));
+    auto bit = CNTK::ElementTimes(it, CNTK::Tanh(projectInput() + CNTK::Times(createProjectionParam(cellDim), stabilizedPrevOutput)));
 
     // Forget-me-not gate
-    auto ft = CNTK::Sigmoid(projectInput() + CNTK::Times(createProjectionParam(cellDim, outputDim), stabilizedPrevOutput) + CNTK::ElementTimes(createDiagWeightParam(cellDim), stabilizedPrevCellState));
+    auto ft = CNTK::Sigmoid(projectInput() + CNTK::Times(createProjectionParam(cellDim), stabilizedPrevOutput) + CNTK::ElementTimes(createDiagWeightParam(cellDim), stabilizedPrevCellState));
     auto bft = CNTK::ElementTimes(ft, prevCellState);
 
     auto ct = bft + bit;
 
     // Output gate
-    auto ot = CNTK::Sigmoid(projectInput() + CNTK::Times(createProjectionParam(cellDim, outputDim), stabilizedPrevOutput) + CNTK::ElementTimes(createDiagWeightParam(cellDim), Stabilize(ct, device)));
+    auto ot = CNTK::Sigmoid(projectInput() + CNTK::Times(createProjectionParam(cellDim), stabilizedPrevOutput) + CNTK::ElementTimes(createDiagWeightParam(cellDim), Stabilize(ct, device)));
     auto ht = CNTK::ElementTimes(ot, CNTK::Tanh(ct));
 
     auto c = ct;
-    auto h = (outputDim != cellDim) ? CNTK::Times(createProjectionParam(outputDim, cellDim), Stabilize(ht, device)) : ht;
+    auto h = (outputDim != cellDim) ? CNTK::Times(createProjectionParam(outputDim), Stabilize(ht, device)) : ht;
 
     return{ h, c };
 }

diff --git a/Tests/UnitTests/V2LibraryTests/FunctionTests.cpp b/Tests/UnitTests/V2LibraryTests/FunctionTests.cpp
index a39044b2c..6167c0616 100644
--- a/Tests/UnitTests/V2LibraryTests/FunctionTests.cpp
+++ b/Tests/UnitTests/V2LibraryTests/FunctionTests.cpp
@@ -99,18 +99,14 @@ void TestReduceSum(size_t sampleRank, const DeviceDescriptor& device)
 
     // Test ReduceSum along a dynamic axis
     {
-        auto testReduceSum = [&sequences, &sequenceLengths, inputShape, sequencesValue, device](const Axis& axis)
+        auto testReduceSum = [&sequences, &sequenceLengths, inputShape, sequencesValue, device]()
         {
-            if (!axis.IsDynamicAxis())
-                RuntimeError("Called the dynamic axis ReduceSum test with a static axis");
-
-            size_t maxActualSequenceLength = sequencesValue->Shape()[inputShape.Rank()];
             size_t numSequences = sequencesValue->Shape()[inputShape.Rank() + 1];
 
             auto inputVar = InputVariable({ inputShape }, DataType::Float, L"input");
-            FunctionPtr reduceSumFunc = ReduceSum(inputVar, axis);
+            FunctionPtr reduceSumFunc = Sequence::ReduceSum(inputVar);
 
-            NDShape maskShape = { ((axis == Axis::DefaultBatchAxis()) ? maxActualSequenceLength : 1), ((axis == Axis::DefaultBatchAxis()) ? 1 : numSequences) };
+            NDShape maskShape = { 1, numSequences };
             NDShape outputShape = reduceSumFunc->Output().Shape();
             auto outputDataShape = outputShape.AppendShape(maskShape);
@@ -130,10 +126,7 @@ void TestReduceSum(size_t sampleRank, const DeviceDescriptor& device)
                     for (size_t k = 0; k < inputShape.TotalSize(); ++k)
                     {
                         float value = sequences[i][(j * inputShape.TotalSize()) + k];
-                        if (axis == Axis::DefaultBatchAxis())
-                            expectedTotals[(j * inputShape.TotalSize()) + k] += value;
-                        else
-                            expectedTotals[(i * inputShape.TotalSize()) + k] += value;
+                        expectedTotals[(i * inputShape.TotalSize()) + k] += value;
                     }
                 }
             }
@@ -141,7 +134,7 @@ void TestReduceSum(size_t sampleRank, const DeviceDescriptor& device)
             FloatingPointVectorCompare(outputData, expectedTotals, "testReduceSum: Forward prop results do not match expected results");
         };
 
-        testReduceSum(Axis::DefaultDynamicAxis());
+        testReduceSum();
     }
 }
@@ -217,11 +210,8 @@ void TestSlice(size_t sampleRank, const DeviceDescriptor& device)
 
     // Test slice along a dynamic axis
     {
-        auto testDynamicAxisSlice = [&sequences, &sequenceLengths, inputShape, sequencesValue, device](const Axis& axis, int beginOffset, int endOffset)
+        auto testDynamicAxisSlice = [&sequences, &sequenceLengths, inputShape, sequencesValue, device](int beginOffset, int endOffset)
         {
-            if (!axis.IsDynamicAxis())
-                RuntimeError("Called the dynamic axis slice test with a static axis");
-
             size_t maxActualSequenceLength = sequencesValue->Shape()[inputShape.Rank()];
             size_t numSequences = sequencesValue->Shape()[inputShape.Rank() + 1];
 
@@ -229,11 +219,11 @@ void TestSlice(size_t sampleRank, const DeviceDescriptor& device)
             size_t maxSliceLength = (endAndBeginOffsetDiff > 0) ? endAndBeginOffsetDiff : maxActualSequenceLength + endAndBeginOffsetDiff;
 
             auto inputVar = InputVariable(inputShape, DataType::Float, L"input");
-            auto sliceFunc = Slice(inputVar, axis, beginOffset, endOffset);
+            auto sliceFunc = Sequence::Slice(inputVar, beginOffset, endOffset);
             sliceFunc = sliceFunc + sliceFunc;
 
-            size_t outputSequenceAxisLength = (axis == Axis::DefaultDynamicAxis()) ? maxSliceLength : maxActualSequenceLength;
-            size_t outputBatchAxisLength = (axis == Axis::DefaultBatchAxis()) ? maxSliceLength : numSequences;
+            size_t outputSequenceAxisLength = maxSliceLength;
+            size_t outputBatchAxisLength = numSequences;
             NDShape outputShape = sliceFunc->Output().Shape().AppendShape({ outputSequenceAxisLength, outputBatchAxisLength });
             std::vector<float> outputData(outputShape.TotalSize(), 0);
             NDMaskPtr mask;
@@ -247,15 +237,15 @@ void TestSlice(size_t sampleRank, const DeviceDescriptor& device)
             std::unordered_map<Variable, ValuePtr> outputs = { { sliceFunc->Output(), outputValue } };
             sliceFunc->Forward({ { inputVar, sequencesValue } }, outputs, device);
 
-            size_t startSequenceIdx = (axis == Axis::DefaultBatchAxis()) ? ((beginOffset >= 0) ? beginOffset : (numSequences + beginOffset)) : 0;
-            size_t endSequenceIdx = (axis == Axis::DefaultBatchAxis()) ? ((endOffset > 0) ? endOffset : (numSequences + endOffset)) : numSequences;
+            size_t startSequenceIdx = 0;
+            size_t endSequenceIdx = numSequences;
 
             std::vector<float> expectedOutputValues(inputShape.TotalSize() * outputSequenceAxisLength * outputBatchAxisLength);
             for (size_t i = startSequenceIdx; i < endSequenceIdx; ++i)
             {
                 size_t currentSequenceLength = sequenceLengths[i];
-                size_t startFrameIdx = (axis == Axis::DefaultDynamicAxis()) ? ((beginOffset >= 0) ? beginOffset : (currentSequenceLength + beginOffset)) : 0;
-                size_t endFrameIdx = (axis == Axis::DefaultDynamicAxis()) ? ((endOffset > 0) ? endOffset : (currentSequenceLength + endOffset)) : currentSequenceLength;
+                size_t startFrameIdx = ((beginOffset >= 0) ? beginOffset : (currentSequenceLength + beginOffset));
+                size_t endFrameIdx = ((endOffset > 0) ? endOffset : (currentSequenceLength + endOffset));
                 size_t j = startFrameIdx;
                 for (; j < endFrameIdx; ++j)
                 {
@@ -272,12 +262,12 @@ void TestSlice(size_t sampleRank, const DeviceDescriptor& device)
             FloatingPointVectorCompare(outputData, expectedOutputValues, "testDynamicAxisSlice: Forward prop results do not match expected results");
         };
 
-        testDynamicAxisSlice(Axis::DefaultDynamicAxis(), 0, 1);
-        testDynamicAxisSlice(Axis::DefaultDynamicAxis(), 0, 2);
-        testDynamicAxisSlice(Axis::DefaultDynamicAxis(), -1, 0);
-        testDynamicAxisSlice(Axis::DefaultDynamicAxis(), -2, 0);
-        testDynamicAxisSlice(Axis::DefaultDynamicAxis(), 0, -1);
-        testDynamicAxisSlice(Axis::DefaultDynamicAxis(), 1, 0);
+        testDynamicAxisSlice(0, 1);
+        testDynamicAxisSlice(0, 2);
+        testDynamicAxisSlice(-1, 0);
+        testDynamicAxisSlice(-2, 0);
+        testDynamicAxisSlice(0, -1);
+        testDynamicAxisSlice(1, 0);
     }
 }

diff --git a/Tests/UnitTests/V2LibraryTests/Seq2Seq.cpp b/Tests/UnitTests/V2LibraryTests/Seq2Seq.cpp
index 643fc33a9..a81f3c050 100644
--- a/Tests/UnitTests/V2LibraryTests/Seq2Seq.cpp
+++ b/Tests/UnitTests/V2LibraryTests/Seq2Seq.cpp
@@ -6,7 +6,7 @@ using namespace CNTK;
 
 using namespace std::placeholders;
 
-void TrainSequenceToSequenceTranslator(const DeviceDescriptor& device, bool useSparseInputs, bool testSaveAndReLoad, bool testCheckpointing, bool addBeamSearchReorderingHook, bool testCloning)
+void TrainSequenceToSequenceTranslator(const DeviceDescriptor& device, bool useSparseInputs, bool testSaveAndReLoad, bool testCheckpointing, bool addBeamSearchReorderingHook, bool testCloning, bool usePlaceholders)
 {
     using namespace std::placeholders;
 
@@ -30,7 +30,7 @@ void TrainSequenceToSequenceTranslator(const DeviceDescriptor& device, bool useS
     FunctionPtr inputSequence = Alias(rawInput, L"inputSequence");
 
     // Drop the sentence start token from the label, for decoder training
-    auto labelSequence = Slice(rawLabels, labelDynamicAxes[0], 1, 0, L"labelSequenceWithStartTrimmed");
+    auto labelSequence = Sequence::Slice(rawLabels, 1, 0, L"labelSequenceWithStartTrimmed");
     auto labelSentenceStart = Sequence::First(rawLabels, L"labelSequenceStart");
 
     auto isFirstLabel = Sequence::IsFirst(labelSequence, L"isFirstLabel");
@@ -38,8 +38,8 @@ void TrainSequenceToSequenceTranslator(const DeviceDescriptor& device, bool useS
     bool forceEmbedding = useSparseInputs;
 
     /* Embeddings */
-    auto inputEmbeddingWeights = Parameter({ inputEmbeddingDim, inputVocabDim }, DataType::Float, GlorotUniformInitializer(), device, L"inputEmbeddingWeights");
-    auto labelEmbeddingWeights = Parameter({ labelEmbeddingDim, labelVocabDim }, DataType::Float, GlorotUniformInitializer(), device, L"labelEmbeddingWeights");
+    auto inputEmbeddingWeights = Parameter({ inputEmbeddingDim, NDShape::InferredDimension }, DataType::Float, GlorotUniformInitializer(), device, L"inputEmbeddingWeights");
+    auto labelEmbeddingWeights = Parameter({ labelEmbeddingDim, NDShape::InferredDimension }, DataType::Float, GlorotUniformInitializer(), device, L"labelEmbeddingWeights");
 
     auto inputEmbedding = Alias((!forceEmbedding && (inputVocabDim <= inputEmbeddingDim)) ? inputSequence : Times(inputEmbeddingWeights, inputSequence), L"inputEmbedding");
     auto labelEmbedding = Alias((!forceEmbedding && (labelVocabDim <= labelEmbeddingDim)) ? labelSequence : Times(labelEmbeddingWeights, labelSequence), L"labelEmbedding");
@@ -63,8 +63,20 @@ void TrainSequenceToSequenceTranslator(const DeviceDescriptor& device, bool useS
         labelSentenceStartEmbeddedScattered = Reshape(labelSentenceStartEmbeddedScattered, labelSentenceStartEmbeddedScattered->Output().Shape().AppendShape({ 1 }), L"labelSentenceStartEmbeddedScattered");
     }
 
-    auto thoughtVectorBroadcastH = Sequence::BroadcastAs(thoughtVectorH, labelEmbedding, L"thoughtVectorBroadcastH");
-    auto thoughtVectorBroadcastC = Sequence::BroadcastAs(thoughtVectorC, labelEmbedding, L"thoughtVectorBroadcastC");
+    auto actualThoughtVectorBroadcastH = Sequence::BroadcastAs(thoughtVectorH, labelEmbedding, L"thoughtVectorBroadcastH");
+    auto actualThoughtVectorBroadcastC = Sequence::BroadcastAs(thoughtVectorC, labelEmbedding, L"thoughtVectorBroadcastC");
+
+    Variable thoughtVectorBroadcastH, thoughtVectorBroadcastC;
+    if (usePlaceholders)
+    {
+        thoughtVectorBroadcastH = PlaceholderVariable();
+        thoughtVectorBroadcastC = PlaceholderVariable();
+    }
+    else
+    {
+        thoughtVectorBroadcastH = actualThoughtVectorBroadcastH;
+        thoughtVectorBroadcastC = actualThoughtVectorBroadcastC;
+    }
 
     /* Decoder */
     auto beamSearchReorderHook = Constant({ 1, 1 }, 1.0f, device);
@@ -116,6 +128,10 @@ void TrainSequenceToSequenceTranslator(const DeviceDescriptor& device, bool useS
     auto biasWeights = Parameter({ labelVocabDim }, 0.0f, device);
 
     auto z = Plus(Times(outputLayerProjWeights, Stabilize(decoderOutput, device)), biasWeights, L"classifierOutput");
+
+    if (usePlaceholders)
+        z->ReplacePlaceholders({ { thoughtVectorBroadcastH, actualThoughtVectorBroadcastH }, { thoughtVectorBroadcastC, actualThoughtVectorBroadcastC } });
+
     auto ce = CrossEntropyWithSoftmax(z, labelSequence, L"lossFunction");
     auto errs = ClassificationError(z, labelSequence, L"classificationError");
 
@@ -218,8 +234,8 @@ void TrainSequenceToSequenceTranslator()
     fprintf(stderr, "\nTrainSequenceToSequenceTranslator..\n");
 
     // TODO: Also test with sparse input variables in the graph
-    TrainSequenceToSequenceTranslator(DeviceDescriptor::CPUDevice(), false, true, false, true, true);
+    TrainSequenceToSequenceTranslator(DeviceDescriptor::CPUDevice(), false, true, false, false, true, true);
 
     if (IsGPUAvailable())
-        TrainSequenceToSequenceTranslator(DeviceDescriptor::GPUDevice(0), false, false, true, false, false);
+        TrainSequenceToSequenceTranslator(DeviceDescriptor::GPUDevice(0), false, false, true, true, false, false);
 }

diff --git a/bindings/python/cntk/cntk_py.i b/bindings/python/cntk/cntk_py.i
index 53f871f93..ed741caf5 100644
--- a/bindings/python/cntk/cntk_py.i
+++ b/bindings/python/cntk/cntk_py.i
@@ -19,6 +19,8 @@
 %rename(gpu_device) CNTK::DeviceDescriptor::GPUDevice;
 %rename(cpu_device) CNTK::DeviceDescriptor::CPUDevice;
 %rename(times_transpose) CNTK::TransposeTimes;
+%rename(sequence_slice) CNTK::Sequence::Slice;
+%rename(sequence_reduce_sum) CNTK::Sequence::ReduceSum;
 
 %rename(momentum_as_time_constant_schedule) CNTK::MomentumAsTimeConstantSchedule;

diff --git a/bindings/python/cntk/ops/sequence/__init__.py b/bindings/python/cntk/ops/sequence/__init__.py
index 42b7e5399..a15732775 100644
--- a/bindings/python/cntk/ops/sequence/__init__.py
+++ b/bindings/python/cntk/ops/sequence/__init__.py
@@ -63,6 +63,28 @@ def is_last(seq, name=''):
     seq = sanitize_input(seq, get_data_type(seq))
     return is_last(seq, name)
 
+@typemap
+def slice(seq, begin_index, end_index, name=''):
+    '''
+    Slice the input sequence.
+
+    Examples:
+        TBA
+    Args:
+        seq: sequence input tensor
+        begin_index (`int`): the index along sequence axis where the slicing starts
+        end_index (`int`): the index along sequence axis where the slicing ends
+        name (`str`, optional): the name of the Function instance in the network
+
+    See also:
+        Indexing in NumPy: http://docs.scipy.org/doc/numpy/reference/arrays.indexing.html
+
+    Returns:
+        :class:`cntk.ops.functions.Function`
+    '''
+    from cntk.cntk_py import sequence_slice
+    seq = sanitize_input(seq, get_data_type(seq))
+    return sequence_slice(seq, begin_index, end_index, name)
 
 @typemap
 def first(seq, name=''):
@@ -281,3 +303,21 @@ def broadcast_as(operand, broadcast_as_operand, name=''):
     broadcast_as_operand = sanitize_input(
         broadcast_as_operand, get_data_type(broadcast_as_operand))
     return broadcast_as(operand, broadcast_as_operand, name)
+
+@typemap
+def reduce_sum(seq, name=''):
+    '''
+    Computes the sum of the input sequence's elements across the sequence axis.
+
+    Examples:
+        TBA
+    Args:
+        seq: sequence input tensor
+        name (`str`, optional): the name of the Function instance in the network
+
+    Returns:
+        :class:`cntk.ops.functions.Function`
+    '''
+    from cntk.cntk_py import sequence_reduce_sum
+    seq = sanitize_input(seq, get_data_type(seq))
+    return sequence_reduce_sum(seq, name)

diff --git a/bindings/python/cntk/ops/tests/reshaping_test.py b/bindings/python/cntk/ops/tests/reshaping_test.py
index 096337dcd..4717bb21e 100644
--- a/bindings/python/cntk/ops/tests/reshaping_test.py
+++ b/bindings/python/cntk/ops/tests/reshaping_test.py
@@ -166,7 +166,7 @@ def test_op_slice_sequence(input_data, slice_params, expected_result, device_id,
                        dynamic_axes=[Axis.default_batch_axis(), t],
                        name='a')
 
-    result = C.slice(a, axis=t, begin_index=slice_params[
+    result = C.sequence.slice(a, begin_index=slice_params[
         0], end_index=slice_params[1])
 
     def grad_slice(x, beg_index, end_index):

diff --git a/bindings/python/examples/Sequence2Sequence/Sequence2Sequence.py b/bindings/python/examples/Sequence2Sequence/Sequence2Sequence.py
index 3b69275ee..90b53f3a5 100644
--- a/bindings/python/examples/Sequence2Sequence/Sequence2Sequence.py
+++ b/bindings/python/examples/Sequence2Sequence/Sequence2Sequence.py
@@ -11,7 +11,7 @@ from cntk import Trainer, Axis, save_model, load_model #, text_format_minibatch_
 from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs, INFINITELY_REPEAT, FULL_DATA_SWEEP
 from cntk.device import cpu, set_default_device
 from cntk.learner import momentum_sgd, momentum_as_time_constant_schedule
-from cntk.ops import input_variable, cross_entropy_with_softmax, classification_error, sequence, slice, past_value, future_value, element_select, alias, hardmax
+from cntk.ops import input_variable, cross_entropy_with_softmax, classification_error, sequence, past_value, future_value, element_select, alias, hardmax
 from cntk.ops.functions import CloneMethod
 
 abs_path = os.path.dirname(os.path.abspath(__file__))
@@ -94,7 +94,7 @@ def sequence_to_sequence_translator(debug_output=False, run_test=False):
     input_sequence = raw_input
 
     # Drop the sentence start token from the label, for decoder training
-    label_sequence = slice(raw_labels, label_seq_axis, 1, 0) # <s> A B C </s> --> A B C </s>
+    label_sequence = sequence.slice(raw_labels, 1, 0) # <s> A B C </s> --> A B C </s>
     label_sentence_start = sequence.first(raw_labels)        # <s>
 
     is_first_label = sequence.is_first(label_sequence)       # <s> 0 0 0 ...
@@ -239,7 +239,7 @@ def sequence_to_sequence_translator(debug_output=False, run_test=False):
     z = load_model("seq2seq.dnn")
 
     label_seq_axis = Axis('labelAxis')
-    label_sequence = slice(find_arg_by_name('raw_labels',z), label_seq_axis, 1, 0)
+    label_sequence = sequence.slice(find_arg_by_name('raw_labels',z), 1, 0)
     ce = cross_entropy_with_softmax(z, label_sequence)
     errs = classification_error(z, label_sequence)
     trainer = Trainer(z, ce, errs, [momentum_sgd(

diff --git a/bindings/python/examples/SequenceClassification/SequenceClassification.py b/bindings/python/examples/SequenceClassification/SequenceClassification.py
index e1e274400..e8a6e8192 100644
--- a/bindings/python/examples/SequenceClassification/SequenceClassification.py
+++ b/bindings/python/examples/SequenceClassification/SequenceClassification.py
@@ -10,11 +10,11 @@ from cntk import Trainer, Axis #, text_format_minibatch_source, StreamConfigurat
 from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs, INFINITELY_REPEAT, FULL_DATA_SWEEP
 from cntk.device import cpu, set_default_device
 from cntk.learner import sgd
-from cntk.ops import input_variable, cross_entropy_with_softmax, classification_error
+from cntk.ops import input_variable, cross_entropy_with_softmax, classification_error, sequence
 
 abs_path = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(os.path.join(abs_path, "..", ".."))
-from examples.common.nn import LSTMP_component_with_self_stabilization, embedding, linear_layer, select_last, print_training_progress
+from examples.common.nn import LSTMP_component_with_self_stabilization, embedding, linear_layer, print_training_progress
 
 # Creates the reader
 def create_reader(path, is_training, input_dim, label_dim):
@@ -28,7 +28,7 @@ def LSTM_sequence_classifer_net(input, num_output_classes, embedding_dim, LSTM_d
     embedding_function = embedding(input, embedding_dim)
     LSTM_function = LSTMP_component_with_self_stabilization(
         embedding_function.output, LSTM_dim, cell_dim)[0]
-    thought_vector = select_last(LSTM_function)
+    thought_vector = sequence.last(LSTM_function)
 
     return linear_layer(thought_vector, num_output_classes)