From 3723b405b06c113825b13e453fbfbdf5e2791e04 Mon Sep 17 00:00:00 2001 From: Amit Agarwal Date: Wed, 23 Nov 2016 18:57:17 -0800 Subject: [PATCH] CNTK v2 library: Refactor Function.h and .cpp into multiple files --- Makefile | 2 + Source/CNTKv2LibraryDll/BackCompat.cpp | 3 +- .../CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj | 5 +- .../CNTKv2LibraryDll.vcxproj.filters | 5 +- Source/CNTKv2LibraryDll/CompositeFunction.cpp | 1515 ++++++++++++ Source/CNTKv2LibraryDll/CompositeFunction.h | 267 ++ .../ComputeInputStatistics.cpp | 2 +- Source/CNTKv2LibraryDll/Function.cpp | 2137 +---------------- Source/CNTKv2LibraryDll/MinibatchSource.cpp | 1 - Source/CNTKv2LibraryDll/PrimitiveFunction.cpp | 654 +++++ .../{Function.h => PrimitiveFunction.h} | 255 -- Source/CNTKv2LibraryDll/Trainer.cpp | 1 - Source/CNTKv2LibraryDll/Utils.cpp | 2 +- Source/CNTKv2LibraryDll/Value.cpp | 2 +- Source/CNTKv2LibraryDll/Value.h | 1 + Source/CNTKv2LibraryDll/Variable.cpp | 2 +- 16 files changed, 2455 insertions(+), 2399 deletions(-) create mode 100644 Source/CNTKv2LibraryDll/CompositeFunction.cpp create mode 100644 Source/CNTKv2LibraryDll/CompositeFunction.h create mode 100644 Source/CNTKv2LibraryDll/PrimitiveFunction.cpp rename Source/CNTKv2LibraryDll/{Function.h => PrimitiveFunction.h} (70%) diff --git a/Makefile b/Makefile index ace984989..0066d29e5 100644 --- a/Makefile +++ b/Makefile @@ -409,6 +409,8 @@ CNTKLIBRARY_COMMON_SRC =\ $(SOURCEDIR)/CNTKv2LibraryDll/BackCompat.cpp \ $(SOURCEDIR)/CNTKv2LibraryDll/Common.cpp \ $(SOURCEDIR)/CNTKv2LibraryDll/Function.cpp \ + $(SOURCEDIR)/CNTKv2LibraryDll/PrimitiveFunction.cpp \ + $(SOURCEDIR)/CNTKv2LibraryDll/CompositeFunction.cpp \ $(SOURCEDIR)/CNTKv2LibraryDll/NDArrayView.cpp \ $(SOURCEDIR)/CNTKv2LibraryDll/NDMask.cpp \ $(SOURCEDIR)/CNTKv2LibraryDll/Trainer.cpp \ diff --git a/Source/CNTKv2LibraryDll/BackCompat.cpp b/Source/CNTKv2LibraryDll/BackCompat.cpp index 5e679af13..56e3702d7 100644 --- a/Source/CNTKv2LibraryDll/BackCompat.cpp +++ b/Source/CNTKv2LibraryDll/BackCompat.cpp @@ -6,7 +6,8 @@ #include "stdafx.h" #include "CNTKLibrary.h" #include "BackCompat.h" -#include "Function.h" +#include "PrimitiveFunction.h" +#include "CompositeFunction.h" #include "ComputationNetworkBuilder.h" #include "Utils.h" #include "ComputationNode.h" diff --git a/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj b/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj index 4711fbb60..b458a374d 100644 --- a/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj +++ b/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj @@ -135,12 +135,13 @@ + - + @@ -151,6 +152,7 @@ + @@ -165,6 +167,7 @@ + diff --git a/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj.filters b/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj.filters index 6eaa7aa78..1cd5b97e6 100644 --- a/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj.filters +++ b/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj.filters @@ -22,6 +22,8 @@ + + @@ -33,7 +35,6 @@ API - @@ -46,6 +47,8 @@ + + diff --git a/Source/CNTKv2LibraryDll/CompositeFunction.cpp b/Source/CNTKv2LibraryDll/CompositeFunction.cpp new file mode 100644 index 000000000..5f365e569 --- /dev/null +++ b/Source/CNTKv2LibraryDll/CompositeFunction.cpp @@ -0,0 +1,1515 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE.md file in the project root for full license information. 
+// + +#include "stdafx.h" +#include "CNTKLibrary.h" +#include "CompositeFunction.h" +#include "ComputationNetworkBuilder.h" +#include "Utils.h" +#include "ComputationNode.h" +#include "ReshapingNodes.h" +#include "EvaluationNodes.h" +#include "TrainingNodes.h" +#include "LinearAlgebraNodes.h" +#include "InputAndParamNodes.h" +#include "NonlinearityNodes.h" +#include "RecurrentNodes.h" +#include "Serialization.h" +#include "Value.h" +#include "RNNNodes.h" + +using namespace Microsoft::MSR::CNTK; + +namespace CNTK +{ + /*static*/ const std::wstring CompositeFunction::CompositeFunctionOpName = L"CompositeFunctionOpName"; + /*static*/ std::atomic CompositeFunction::s_nextAutoGeneratedDynamicAxis(0); + + static const std::wstring s_compositeFunctionTypeValue = L"CompositeFunction"; + + /*virtual*/ Dictionary CompositeFunction::Serialize() const + { + Dictionary dict; + + dict[versionKey] = CurrentVersion(); + dict[typeKey] = s_compositeFunctionTypeValue; + dict[rootKey] = RootFunction()->Uid(); + dict[nameKey] = Name(); + dict[uidKey] = Uid(); + + + // Find cycles in the graph and "break" them by inserting placeholders. + // This needs to be done on Save, since here we have easy access to the shape and + // dynamic axis info. + std::unordered_set visitedFunctions; + std::vector topoSortedPrimitiveFunctions; + std::vector inputs; + std::unordered_set inputUids; + Traverse(RootFunction(), [&visitedFunctions, &inputs, &topoSortedPrimitiveFunctions, &inputUids](const FunctionPtr& function) { + std::vector functionInputs = function->Inputs(); + for (const auto& input : functionInputs) + { + auto& uid = input.Uid(); + if (inputUids.find(uid) != inputUids.end()) + { + continue; + } + + // check if this input corresponds to a cyclic edge in the graph. + bool mustBeReplaced = input.IsOutput() && visitedFunctions.find(input.Owner()) != visitedFunctions.end(); + + if (mustBeReplaced) + { + auto varKind = VariableKind::Placeholder; + Variable var(input.Shape(), varKind, input.GetDataType(), nullptr, + input.IsSparse(), input.DynamicAxes(), input.Name(), uid); + inputs.push_back(var); + inputUids.insert(uid); + } + else if (!input.IsOutput()) + { + // leave the input as is. 
+ inputs.push_back(input); + inputUids.insert(uid); + } + } + visitedFunctions.insert(function); + topoSortedPrimitiveFunctions.push_back(function); + }); + + std::reverse(std::begin(topoSortedPrimitiveFunctions), std::end(topoSortedPrimitiveFunctions)); + + assert(topoSortedPrimitiveFunctions.size() == m_allPrimitiveFunctions.size()); + assert(topoSortedPrimitiveFunctions.back()->Uid() == RootFunction()->Uid()); + + std::vector inputDictionaries; + inputDictionaries.reserve(inputs.size()); + inputUids.clear(); + for (const auto& input : inputs) + { + if (inputUids.find(input.Uid()) != inputUids.end()) + { + LogicError("Input uids must be unique"); + } + inputUids.insert(input.Uid()); + inputDictionaries.push_back(input.Serialize()); + } + + dict[inputsKey] = std::move(inputDictionaries); + + std::vector functionDictionaries; + std::unordered_set outputUids; + for (const auto& primitiveFunciton : topoSortedPrimitiveFunctions) + { + for (const auto& output : primitiveFunciton->Outputs()) + { + if (outputUids.find(output.Uid()) != outputUids.end()) + { + LogicError("Output uids of all primitive functions in a function graph must be unique"); + } + outputUids.insert(primitiveFunciton->Uid()); + } + functionDictionaries.push_back(primitiveFunciton->Serialize()); + } + + dict[functionsKey] = std::move(functionDictionaries); + + return dict; + } + + /*static*/ FunctionPtr CompositeFunction::Deserialize(const Dictionary& dict, const CNTK::DeviceDescriptor& device) + { + static const vector s_requiredDictionaryKeys = { typeKey, rootKey, nameKey, uidKey, inputsKey, functionsKey }; + + size_t version = ValidateDictionary(dict, s_requiredDictionaryKeys, s_compositeFunctionTypeValue, s_serializationVersion); + + const auto& rootUid = dict[rootKey].Value(); + const auto& name = dict[nameKey].Value(); + const auto& uid = dict[uidKey].Value(); + const auto& inputs = dict[inputsKey].Value>(); + + std::unordered_map uidToInputMap(inputs.size()); + + for (const auto& dictionaryValue : inputs) + { + const auto& dictionary = dictionaryValue.Value(); + const auto& inputVar = Variable::Deserialize(dictionary, device); + + if (uidToInputMap.find(inputVar.Uid()) != uidToInputMap.end()) + { + LogicError("Input uids are not unique (several inputs share '%ls' uid) " + "(%s).", inputVar.Uid().c_str(), GetVersionsString(s_serializationVersion, version).c_str()); + } + uidToInputMap[inputVar.Uid()] = inputVar; + } + + const auto& functions = dict[functionsKey].Value>(); + + FunctionPtr root; + std::unordered_map placeholderReplacements; + std::unordered_set allPrimitiveFunctions; // this keeps all primitive functions alive until a composite function is created. + for (const auto& dictionaryValue : functions) + { + root = PrimitiveFunction::Deserialize(dictionaryValue.Value(), uidToInputMap, device); + allPrimitiveFunctions.insert(root); + + auto primitiveFunction = dynamic_cast(root.get()); + // Since Combine simply forwards other functions' outputs, all of its outputs + // should already be in the uidToInputMap. 
+ if (primitiveFunction->OpType() == PrimitiveOpType::Combine) + { + continue; + } + + for (const auto& output : root->Outputs()) + { + const auto& it = uidToInputMap.find(output.Uid()); + if (it != uidToInputMap.end()) + { + if (!it->second.IsPlaceholder()) + { + LogicError("Unexpected variable type %ls instead of a Placeholder for input %ls variable (uid = %ls)" + "(%s).", VariableKindName(it->second.Kind()), it->second.Name().c_str(), it->second.Uid().c_str(), + GetVersionsString(s_serializationVersion, version).c_str()); + } + placeholderReplacements[it->second] = output; + } + else + { + uidToInputMap[output.Uid()] = output; + } + } + } + + if (root->Uid() != rootUid) + { + LogicError("Root UID '%ls' is different from the expected value '%ls'.", root->Uid().c_str(), rootUid.c_str()); + } + + if (placeholderReplacements.size() > 0) + { + return CompositeFunction::Create(root->ReplacePlaceholders(placeholderReplacements), name, uid); + } + + return CompositeFunction::Create(root, name, uid); + } + + // Names of the dynamic axes in the CNTK engine for some special sets of dynamic axes values + // Note: The no sequence axis corresponds to a special case where there is no sequence axis (i.e. has been reduced over) + // and the special name is used to identify this when loading back a model saved in CNTK v1 format. This will not really be needed + // when the new CNTK v2 model serialization format is ready. + /*static*/ const std::wstring CompositeFunction::InternalDefaultDynamicAxisName = L"*"; + /*static*/ const std::wstring CompositeFunction::InternalNoSequenceAxisName = L"__noSequenceAxis"; + + // Replace any PlaceHolder Variables in the graph of Functions underlying 'this' CompositeFunction. All PlaceHolder variables + // should have been replaced before performing any Forward compute of 'this' Function. 
+ /*virtual*/ void CompositeFunction::ReplacePlaceholdersInPlace(const std::unordered_map& placeholderReplacements, + std::unordered_set& visitedFunctions, + std::unordered_set& replacedPlaceholders) + { + RootFunction()->ReplacePlaceholdersInPlace(placeholderReplacements, visitedFunctions, replacedPlaceholders); + + // If any of the placeholders were replaced with Output variables, let's add the graph of function underneath each of those to 'm_allPrimitiveFunctions' set + for (auto replacedPlaceholder : replacedPlaceholders) + { + auto replacingVariable = placeholderReplacements.at(replacedPlaceholder); + if (replacingVariable.IsOutput()) + { + auto ownerFunc = replacingVariable.Owner(); + std::unordered_set visitedFunctions; + Collect(ownerFunc, visitedFunctions); + + // Add the newly visited functions to 'm_allPrimitiveFunctions' set + m_allPrimitiveFunctions.insert(visitedFunctions.begin(), visitedFunctions.end()); + } + } + std::unordered_map functionVisitCounts; + + // An arbitrary cap on changing output shape of recurrent nodes, to detect infinite inference loops + const size_t maxNumValidationPassesAllowed = 25; + bool recurrentNodeOutputModified = false; + size_t numValidationPasses = 0; + do + { + recurrentNodeOutputModified = false; + functionVisitCounts.clear(); + RootFunction()->ValidateOrUpdateOutputs(functionVisitCounts, recurrentNodeOutputModified); + numValidationPasses++; + } while (recurrentNodeOutputModified && (numValidationPasses < maxNumValidationPassesAllowed)); + + if (numValidationPasses >= maxNumValidationPassesAllowed) + LogicError("A recurrent node output shape change happened in successive %d validation passes indicating a potential infinite inference loop!", (int)numValidationPasses); + } + + // Recursively create a sub-network of ComputationNode instances corresponding to the graph of Functions + // underlying the specified 'variable' and return the ComputationNode instance that corresponds to the + // top level 'variable' + template + /*static*/ ComputationNodeBasePtr CompositeFunction::GetNode(const Variable& variable, + Microsoft::MSR::CNTK::ComputationNetworkPtr& network, + ComputationNetworkBuilder& builder, + std::unordered_map& variableToNodeMap, + std::unordered_map& isVariableRootMap) + { + auto iter = variableToNodeMap.find(variable); + if (iter != variableToNodeMap.end()) + { + isVariableRootMap[variable] = false; + return iter->second; + } + + // The DataType, Shape and DynamicAxes of the variable must be known by now + if (variable.GetDataType() == DataType::Unknown) + InvalidArgument("Variable%S with unknown DataType detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str()); + + if (variable.Shape().IsUnknown()) + InvalidArgument("Variable%S with unknown shape detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str()); + + if (variable.Shape().HasInferredDimension()) + InvalidArgument("Variable%S with InferredDimension for at least one axis in its shape, detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str()); + + if (variable.DynamicAxes() == Axis::UnknownDynamicAxes()) + InvalidArgument("Variable%S with unknown dynamic axes detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str()); + + // Lets add a null entry in the map for this variable, to break infinite recursion when processing recurrent graphs + variableToNodeMap[variable] = nullptr; + + std::shared_ptr> computationNodePtr; + if (variable.IsParameter() || 
variable.IsConstant()) + { + auto internalNodeName = CNTKInternalNodeNameFromUidAndName(variable.Uid(), variable.Name()); + computationNodePtr = builder.CreateLearnableParameter(internalNodeName, AsTensorShape(variable.Shape())); + network->InitLearnableParameters(computationNodePtr, L"fixedValue", 0); // must call this to follow protocol; can overwrite later + if (!variable.NeedsGradient()) + computationNodePtr->SetLearningRateMultiplier(0.0); + + NDArrayViewPtr value = variable.IsConstant() ? Constant(variable).Value() : Parameter(variable).Value(); + std::shared_ptr> valueMatrix = variable.IsConstant() ? value->GetMatrix() : value->GetWritableMatrix(); + + if (variable.IsParameter() || (valueMatrix->GetDeviceId() == network->GetDeviceId())) + computationNodePtr->Value() = valueMatrix->AsReference(); + else + { + Matrix clonedMatrix(valueMatrix->GetNumRows(), valueMatrix->GetNumCols(), network->GetDeviceId(), valueMatrix->GetMatrixType(), valueMatrix->GetFormat()); + clonedMatrix.AssignValuesOf(*valueMatrix); + computationNodePtr->Value() = std::move(clonedMatrix); + } + } + else if (variable.IsInput()) + { + auto internalNodeName = CNTKInternalNodeNameFromUidAndName(variable.Uid(), variable.Name()); + + // TODO: Input variables currently are required to have the default batch axis + auto dynamicAxes = variable.DynamicAxes(); + auto foundDefaultBatchAxis = std::find(dynamicAxes.begin(), dynamicAxes.end(), Axis::DefaultBatchAxis()); + if (foundDefaultBatchAxis == dynamicAxes.end()) + LogicError("Currently Input Variables are required to have the DefaultBatchAxis as one of their dynamic axes"); + + if (dynamicAxes.back() != Axis::DefaultBatchAxis()) + LogicError("Currently Input Variables are required to have the DefaultBatchAxis as their last dynamic axes"); + + // TODO: Support inputs with > 1 dynamic axes + if ((dynamicAxes.size() < 1) || (dynamicAxes.size() > 2)) + LogicError("Currently only Input variables with 1 or 2 dynamic axis are supported"); + + // Construct the dynamic axis name to be used internally for the CNTK InputNodes + std::wstring internalDynamicAxisName = InternalDynamicAxisNameFromDynamicAxes(dynamicAxes); + + if (!internalDynamicAxisName.empty() && !network->NodeNameExists(internalDynamicAxisName)) + network->AddNodeToNetAndAttachInputs(New>(network->GetDeviceId(), internalDynamicAxisName), {}); + + if (IsSparseInput(variable)) + computationNodePtr = builder.CreateSparseInputNode(internalNodeName, AsTensorShape(variable.Shape()), internalDynamicAxisName); + else + computationNodePtr = builder.CreateInputNode(internalNodeName, AsTensorShape(variable.Shape()), internalDynamicAxisName); + + if (variable.NeedsGradient()) + { + // Set a dummy learning rate multiplier to force gradient computation for the input computation node since by default + // gradients are not computed for Input nodes + computationNodePtr->SetLearningRateMultiplier(0.00001f); + } + } + else + { + assert(variable.IsOutput()); + computationNodePtr = GetOutputVariableNode(variable, network, builder, variableToNodeMap, isVariableRootMap)->template As>()->shared_from_this(); + } + + variableToNodeMap[variable] = computationNodePtr; + if (isVariableRootMap.find(variable) == isVariableRootMap.end()) + isVariableRootMap[variable] = variable.IsOutput(); + + return computationNodePtr; + } + + template + /*static*/ ComputationNodeBasePtr CompositeFunction::CreateComputationNode(const Variable& variable, + PrimitiveFunction* primitiveFunction, + const std::vector>>& inputNodes, + 
Microsoft::MSR::CNTK::ComputationNetworkPtr& network, + std::unordered_map& variableToNodeMap) + { + ComputationNodeBasePtr computationNodePtr; + + auto internalNodeName = CNTKInternalNodeNameFromUidAndName(primitiveFunction->Uid(), primitiveFunction->Name()); + + auto& functionConfig = primitiveFunction->Attributes(); + auto functionInputs = primitiveFunction->Inputs(); + PrimitiveOpType op = primitiveFunction->OpType(); + + switch (op) + { + case PrimitiveOpType::Negate: + computationNodePtr = New>(network->GetDeviceId(), internalNodeName); + break; + case PrimitiveOpType::Sigmoid: + computationNodePtr = New>(network->GetDeviceId(), internalNodeName); + break; + case PrimitiveOpType::Tanh: + computationNodePtr = New>(network->GetDeviceId(), internalNodeName); + break; + case PrimitiveOpType::ReLU: + computationNodePtr = New>(network->GetDeviceId(), internalNodeName); + break; + case PrimitiveOpType::Exp: + computationNodePtr = New>(network->GetDeviceId(), internalNodeName); + break; + case PrimitiveOpType::Log: + computationNodePtr = New>(network->GetDeviceId(), internalNodeName); + break; + case PrimitiveOpType::Sqrt: + computationNodePtr = New>(network->GetDeviceId(), internalNodeName); + break; + case PrimitiveOpType::Floor: + computationNodePtr = New>(network->GetDeviceId(), internalNodeName); + break; + case PrimitiveOpType::Abs: + computationNodePtr = New>(network->GetDeviceId(), internalNodeName); + break; + case PrimitiveOpType::Reciprocal: + computationNodePtr = New>(network->GetDeviceId(), internalNodeName); + break; + case PrimitiveOpType::Softmax: + computationNodePtr = New>(network->GetDeviceId(), internalNodeName); + break; + case PrimitiveOpType::Hardmax: + computationNodePtr = New>(network->GetDeviceId(), internalNodeName); + break; + case PrimitiveOpType::TransposeAxes: + { + auto axis1 = functionConfig[PrimitiveFunction::AttributeNameAxis1].Value(); + auto axis2 = functionConfig[PrimitiveFunction::AttributeNameAxis2].Value(); + + // The axis ids passed to the internal CNTK TransposeDimensionsNode are 1 based instead of 0 based + computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsCNTKInternalAxisIdx(axis1), AsCNTKInternalAxisIdx(axis2)); + break; + } + case PrimitiveOpType::Where: + { + auto dynamicAxes = variable.DynamicAxes(); + auto internalCNTKWhereNodeDynamicAxisName = InternalDynamicAxisNameFromDynamicAxes(dynamicAxes); + computationNodePtr = New>(network->GetDeviceId(), internalNodeName, internalCNTKWhereNodeDynamicAxisName); + break; + } + case PrimitiveOpType::Slice: + { + auto axis = functionConfig[PrimitiveFunction::AttributeNameAxis].Value(); + auto beginIndex = functionConfig[PrimitiveFunction::AttributeNameBeginIndex].Value(); + auto endIndex = functionConfig[PrimitiveFunction::AttributeNameEndIndex].Value(); + + // Internal CNTK SliceNode takes 1 based axis indices instead of 0 based + computationNodePtr = New>(network->GetDeviceId(), internalNodeName, beginIndex, endIndex, AsCNTKInternalAxisIdx(axis)); + break; + } + case PrimitiveOpType::RandomSample: + { + auto numSamples = functionConfig[PrimitiveFunction::AttributeNameNumSamples].Value(); + auto allowDuplicates = functionConfig[PrimitiveFunction::AttributeNameAllowDuplicates].Value(); + computationNodePtr = New>(network->GetDeviceId(), internalNodeName, numSamples, allowDuplicates); + break; + } + case PrimitiveOpType::RandomSampleInclusionFrequency: + { + auto numSamples = functionConfig[PrimitiveFunction::AttributeNameNumSamples].Value(); + auto allowDuplicates = 
functionConfig[PrimitiveFunction::AttributeNameAllowDuplicates].Value(); + computationNodePtr = New>(network->GetDeviceId(), internalNodeName, numSamples, allowDuplicates); + break; + } + case PrimitiveOpType::Dropout: + { + auto dropoutRate = functionConfig[PrimitiveFunction::AttributeNameDropoutRate].Value(); + computationNodePtr = New>(network->GetDeviceId(), internalNodeName); + computationNodePtr->As>()->SetDropoutRate(dropoutRate); + break; + } + case PrimitiveOpType::Reshape: + { + auto newShape = functionConfig[PrimitiveFunction::AttributeNameNewShape].Value(); + computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsTensorShape(newShape)); + break; + } + case PrimitiveOpType::ROIPooling: + { + auto roiOutputShape = functionConfig[PrimitiveFunction::AttributeNameROIOutputShape].Value(); + computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsTensorShape(roiOutputShape)); + break; + } + case PrimitiveOpType::Pooling: + { + PoolingType poolingType = (PoolingType)(functionConfig[PrimitiveFunction::AttributeNamePoolingType].Value()); + auto poolingWindowsShape = functionConfig[PrimitiveFunction::AttributeNamePoolingWindowShape].Value(); + auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value(); + auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value(); + auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value(); + auto autoPadding = AsVector(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value>()); + computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsCNTKPoolKind(poolingType), AsTensorShape(poolingWindowsShape), AsTensorShape(strides), autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), ImageLayoutKind::CHW); + break; + } + case PrimitiveOpType::SumAll: + computationNodePtr = New>(network->GetDeviceId(), internalNodeName); + break; + case PrimitiveOpType::Plus: + computationNodePtr = New>(network->GetDeviceId(), internalNodeName); + break; + case PrimitiveOpType::LogPlus: + computationNodePtr = New>(network->GetDeviceId(), internalNodeName); + break; + case PrimitiveOpType::Minus: + computationNodePtr = New>(network->GetDeviceId(), internalNodeName); + break; + case PrimitiveOpType::ElementTimes: + computationNodePtr = New>(network->GetDeviceId(), internalNodeName); + break; + case PrimitiveOpType::Equal: + computationNodePtr = New>(network->GetDeviceId(), internalNodeName); + break; + case PrimitiveOpType::NotEqual: + computationNodePtr = New>(network->GetDeviceId(), internalNodeName); + break; + case PrimitiveOpType::Less: + computationNodePtr = New>(network->GetDeviceId(), internalNodeName); + break; + case PrimitiveOpType::LessEqual: + computationNodePtr = New>(network->GetDeviceId(), internalNodeName); + break; + case PrimitiveOpType::Greater: + computationNodePtr = New>(network->GetDeviceId(), internalNodeName); + break; + case PrimitiveOpType::GreaterEqual: + computationNodePtr = New>(network->GetDeviceId(), internalNodeName); + break; + case PrimitiveOpType::Times: + { + size_t outputRank = functionConfig[PrimitiveFunction::AttributeNameOutputRank].Value(); + auto inferInputRankToMap = functionConfig[PrimitiveFunction::AttributeNameInferInputRankToMap].Value(); + computationNodePtr = New>(network->GetDeviceId(), internalNodeName, outputRank, inferInputRankToMap); + break; + } + case PrimitiveOpType::TransposeTimes: + { + size_t outputRank = functionConfig[PrimitiveFunction::AttributeNameOutputRank].Value(); + computationNodePtr = 
New>(network->GetDeviceId(), internalNodeName, outputRank); + break; + } + case PrimitiveOpType::Convolution: + { + NDShape outputMapCount, kernelShape; + std::tie(outputMapCount, kernelShape) = GetConvolutionOutputMapCountAndKernelShape(functionInputs[0].Shape(), functionInputs[1].Shape()); + auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value(); + auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value(); + auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value(); + auto sharing = AsVector(functionConfig[PrimitiveFunction::AttributeNameSharing].Value>()); + auto autoPadding = AsVector(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value>()); + auto transpose = functionConfig[PrimitiveFunction::AttributeNameTranspose].Value(); + auto maxTempMemSizeInSamples = functionConfig[PrimitiveFunction::AttributeNameMaxTempMemSizeInSamples].Value(); + computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsTensorShape(kernelShape), AsTensorShape(outputMapCount), AsTensorShape(strides), sharing, autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), transpose, ImageLayoutKind::CHW, maxTempMemSizeInSamples); + break; + } + case PrimitiveOpType::Logistic: + computationNodePtr = New>(network->GetDeviceId(), internalNodeName); + break; + case PrimitiveOpType::SquaredError: + computationNodePtr = New>(network->GetDeviceId(), internalNodeName); + break; + case PrimitiveOpType::CrossEntropyWithSoftmax: + computationNodePtr = New>(network->GetDeviceId(), internalNodeName); + break; + case PrimitiveOpType::ClassificationError: + computationNodePtr = New>(network->GetDeviceId(), internalNodeName); + break; + case PrimitiveOpType::PastValue: + case PrimitiveOpType::FutureValue: + { + Variable inputOperandVar = functionInputs[0]; + Variable initialStateVar = functionInputs[1]; + + size_t offset = primitiveFunction->Attributes()[PrimitiveFunction::AttributeNameOffset].Value(); + if (op == PrimitiveOpType::PastValue) + computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsTensorShape(inputOperandVar.Shape()), offset); + else + computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsTensorShape(inputOperandVar.Shape()), offset); + + break; + } + case PrimitiveOpType::ReduceElements: + { + auto reductionAxis = functionConfig[PrimitiveFunction::AttributeNameAxis].Value(); + auto reductionOpName = functionConfig[PrimitiveFunction::AttributeNameReductionOpName].Value(); + computationNodePtr = New>(network->GetDeviceId(), internalNodeName, reductionOpName, AsCNTKInternalAxisIdx(reductionAxis)); + break; + } + case PrimitiveOpType::BatchNormalization: + { + auto spatial = functionConfig[PrimitiveFunction::AttributeNameSpatial].Value(); + auto normalizationTimeConstant = functionConfig[PrimitiveFunction::AttributeNameNormalizationTimeConstant].Value(); + auto blendTimeConstant = functionConfig[PrimitiveFunction::AttributeNameBlendTimeConstant].Value(); + auto epsilon = functionConfig[PrimitiveFunction::AttributeNameEpsilon].Value(); + auto useCuDNNEngine = functionConfig[PrimitiveFunction::AttributeNameUseCuDNNEngine].Value(); + + computationNodePtr = New>(network->GetDeviceId(), internalNodeName, spatial, normalizationTimeConstant, blendTimeConstant, epsilon, !useCuDNNEngine, ImageLayoutKind::CHW); + break; + } + case PrimitiveOpType::Combine: + // This operation is just a no-op and is a means to combine multiple functions to create a single Function + // whose outputs are a union of 
the outputs of the Functions being combined. + computationNodePtr = variableToNodeMap[variable]; + break; + case PrimitiveOpType::PackedIndex: + computationNodePtr = New>(network->GetDeviceId(), internalNodeName); + break; + case PrimitiveOpType::GatherPacked: + computationNodePtr = New>(network->GetDeviceId(), internalNodeName); + break; + case PrimitiveOpType::ScatterPacked: + computationNodePtr = New>(network->GetDeviceId(), internalNodeName); + break; + case PrimitiveOpType::Clip: + computationNodePtr = New>(network->GetDeviceId(), internalNodeName); + break; + case PrimitiveOpType::Select: + computationNodePtr = New>(network->GetDeviceId(), internalNodeName); + break; + case PrimitiveOpType::Splice: + { + Axis spliceAxis = functionConfig[PrimitiveFunction::AttributeNameAxis].Value(); + computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsCNTKInternalAxisIdx(spliceAxis)); + break; + } + case PrimitiveOpType::OptimizedRNNStack: + { + auto bidirectional = functionConfig[PrimitiveFunction::AttributeNameBidirectional].Value(); + auto numLayers = functionConfig[PrimitiveFunction::AttributeNameNumLayers].Value(); + auto hiddenSize = functionConfig[PrimitiveFunction::AttributeNameHiddenSize].Value(); + auto recurrentOp = functionConfig[PrimitiveFunction::AttributeNameRecurrentOp].Value(); + + computationNodePtr = New>(network->GetDeviceId(), internalNodeName, bidirectional, numLayers, hiddenSize, recurrentOp); + break; + } + case PrimitiveOpType::ReconcileDynamicAxis: + { + computationNodePtr = New>(network->GetDeviceId(), internalNodeName); + break; + } + case PrimitiveOpType::LogSoftmax: + { + //This can be implemented as x => x - ReduceLogSum(x). How to do this here? + computationNodePtr = New>(network->GetDeviceId(), internalNodeName); + break; + } + default: + LogicError("Specified op %S not yet supported", PrimitiveOpTypeName(op).c_str()); + break; + } + + std::vector inputNodesBasePtrs; + for (auto inputNode : inputNodes) + inputNodesBasePtrs.push_back(inputNode); + + // Let's reorder inputNodesBasePtrs properly since the ordering of inputs of CNTK internal ComputationNode may be different from the PrimitiveFunction inputs ordering + ReorderAsCNTKComputationNodeInputs(op, inputNodesBasePtrs); + if (computationNodePtr->Is()) + { + auto computationNodeExpectedInputCount = computationNodePtr->As()->GetExpectedNumInputs(); + if (computationNodeExpectedInputCount != inputNodesBasePtrs.size()) + LogicError("Input count mismatch: The Primitive function for op %S has %d inputs while the corresponding ComputationNode has %d inputs", + PrimitiveOpTypeName(op).c_str(), + (int)inputNodesBasePtrs.size(), + (int)computationNodeExpectedInputCount); + } + + network->AddNodeToNetAndAttachInputs(computationNodePtr, inputNodesBasePtrs); + + return computationNodePtr; + } + + template + /*static*/ ComputationNodeBasePtr CompositeFunction::GetOutputVariableNode(const Variable& variable, + Microsoft::MSR::CNTK::ComputationNetworkPtr& network, + ComputationNetworkBuilder& builder, + std::unordered_map& variableToNodeMap, + std::unordered_map& isVariableRootMap) + { + assert(variable.IsOutput()); + + Function* function = variable.Owner().get(); + ComputationNodeBasePtr computationNodePtr; + if (dynamic_cast(function)) + { + PrimitiveFunction* primitiveFunction = dynamic_cast(function); + PrimitiveOpType op = primitiveFunction->OpType(); + auto& functionInputs = primitiveFunction->m_inputs; + + DataType nonConstInputDataType = DataType::Unknown; + for (auto& inputVar : functionInputs) + { + if 
(!inputVar.IsConstant() && (inputVar.GetDataType() != DataType::Unknown)) + { + nonConstInputDataType = inputVar.GetDataType(); + break; + } + } + + // Create the nodes corresponding to the inputs + std::vector>> inputNodes; + for (auto& inputVar : functionInputs) + { + // If the inputVar is a constant and not the right DataType let's coerce it to the right type + if (inputVar.IsConstant() && (nonConstInputDataType != DataType::Unknown) && (inputVar.GetDataType() != nonConstInputDataType)) + { + auto originalConstantValue = Constant(inputVar).Value(); + auto constantValueCPU = originalConstantValue->DeepClone(DeviceDescriptor::CPUDevice(), true); + NDArrayViewPtr newConstantValue = CloneAsDataType(constantValueCPU, nonConstInputDataType, true); + inputVar = Constant(newConstantValue->DeepClone(originalConstantValue->Device(), originalConstantValue->IsReadOnly()), inputVar.Name()); + } + + auto baseNodePtr = GetNode(inputVar, network, builder, variableToNodeMap, isVariableRootMap); + inputNodes.push_back((baseNodePtr != nullptr) ? baseNodePtr->template As>()->shared_from_this() : nullptr); + } + + computationNodePtr = CreateComputationNode(variable, primitiveFunction, inputNodes, network, variableToNodeMap); + if (op != PrimitiveOpType::Combine) + { + for (auto inputVar : functionInputs) + isVariableRootMap[inputVar] = false; + } + } + else + LogicError("User defined Functions are currently unsupported!"); + + return computationNodePtr; + } + + template + ComputationNetworkPtr CompositeFunction::GetComputationNetwork(const DeviceDescriptor& device, const std::unordered_set& backpropRoots, bool allocateNetworkMatrices) + { + if (m_computationNetwork != nullptr) + { + // TODO: We should either invalidate and readapt the network if he backpropRoots change compared to what was specified when the network + // was last constructed, to just recreate a new network. 
+ // For now just disallow changing the backpropRoots after the network is created + if (!backpropRoots.empty() && (m_currentBackpropRoots != backpropRoots)) + LogicError("Changing backprop roots across different Forward calls on a CNTK composite Function is currently unsupported"); + + // TODO: Support changing the device across different invocations of the forward method on a Function instance + if (AsDeviceDescriptor(m_computationNetwork->GetDeviceId()) != device) + LogicError("Changing device across different Forward calls on a CNTK composite Function is currently unsupported"); + + } + else + { + m_computationNetwork = std::make_shared(AsCNTKImplDeviceId(device)); + + ComputationNetworkBuilder builder(*m_computationNetwork); + + // TODO: We currently only support one backprop root + if (backpropRoots.size() > 1) + LogicError("More than one backprop roots is currently unsupported"); + + auto placeholders = Placeholders(); + if (!placeholders.empty()) + InvalidArgument("All placeholders of a Function must be bound before performing a Forward computation on the Function!"); + + // Now recursively create the network in a top-down fashion + auto rootFunction = RootFunction(); + auto rootFunctionOutputs = rootFunction->Outputs(); + for (auto rootOutput : rootFunctionOutputs) + GetNode(rootOutput, m_computationNetwork, builder, m_variableToNodeMap, m_isVariableRootMap); + + // If any of the function outputs is not a root node, we need to explicitly add it to the 'output' group of the ComputationNetwork + for (auto rootOutput : rootFunctionOutputs) + { + if (!m_isVariableRootMap[rootOutput]) + m_computationNetwork->AddToNodeGroup(L"output", m_variableToNodeMap[rootOutput]); + } + + m_currentBackpropRoots = backpropRoots; + + // In case of recurrence, the inputs of some of the ComputationNodes are not attached due to cycles. 
+ // Now attach those after we have created all ComputationNodes in the network + for (auto varNodePair : m_variableToNodeMap) + { + auto& currentComputationNode = varNodePair.second; + auto& currentComputationNodeInputs = currentComputationNode->GetInputs(); + auto& currentVar = varNodePair.first; + + if (std::find(currentComputationNodeInputs.begin(), currentComputationNodeInputs.end(), nullptr) != currentComputationNodeInputs.end()) + { + // This ComputationNode has at least one null input which now needs to be properly attached + + const PrimitiveFunction* primitiveFunc = dynamic_cast(currentVar.Owner().get()); + + // Let's reorder properly since the ordering of inputs of CNTK internal ComputationNode may be different from the PrimitiveFunction inputs ordering + auto inputVars = primitiveFunc->Inputs(); + ReorderAsCNTKComputationNodeInputs(primitiveFunc->OpType(), inputVars); + inputVars.resize(currentComputationNode->GetNumInputs()); + + std::vector inputNodesBasePtrs; + for (auto inputVar : inputVars) + inputNodesBasePtrs.push_back(m_variableToNodeMap[inputVar]); + + currentComputationNode->AttachInputs(inputNodesBasePtrs); + } + } + + m_computationNetwork->SetTraceLevel(Internal::GetComputationNetworkTraceLevel()); + m_computationNetwork->CompileNetwork(); + + // Verify that the shapes of the output Variables that we computed match the corresponding nodes in the ComputationNetwork + for (auto varNodePair : m_variableToNodeMap) + { + if (varNodePair.first.IsOutput()) + { + auto outputVar = varNodePair.first; + auto computationNodePtr = m_variableToNodeMap[outputVar]; + auto outputShape = outputVar.Shape(); + auto computationNodeSampleLayout = computationNodePtr->GetSampleLayout(); + if (((outputShape.Rank() == 0) && (computationNodeSampleLayout[0] != 1)) || + ((outputShape.Rank() != 0) && (computationNodeSampleLayout != AsTensorViewShape(outputShape)) && (computationNodeSampleLayout != AsTensorShape(outputShape)))) + { + LogicError("The output Variable shape %S does not match the SampleLayout shape %s of the corresponding ComputationNode in the network", outputShape.AsString().c_str(), ((std::string)computationNodeSampleLayout).c_str()); + } + } + } + + // Record the timestamps of Parameter values + assert(m_lastRecordedParameterValueTimeStamps.empty()); + auto functionParameters = Parameters(); + for (auto parameter : functionParameters) + m_lastRecordedParameterValueTimeStamps.insert({ parameter, parameter.CurrentValueTimeStamp() }); + } + + + if (!m_networkMatricesAllocated && allocateNetworkMatrices) + { + ComputationNodeBasePtr backpropRootNode; + + // Now recursively traverse the network in a top-down fashion + auto rootFunction = RootFunction(); + auto rootFunctionOutputs = rootFunction->Outputs(); + std::vector forwardRootNodes; + for (auto rootOutput : rootFunctionOutputs) + { + auto currentRootNode = m_variableToNodeMap[rootOutput]; + forwardRootNodes.push_back(currentRootNode); + + if (m_currentBackpropRoots.find(rootOutput) != m_currentBackpropRoots.end()) + backpropRootNode = currentRootNode; + } + + m_computationNetwork->AllocateAllMatrices(forwardRootNodes, {}, backpropRootNode); + m_networkMatricesAllocated = allocateNetworkMatrices; + } + + return m_computationNetwork; + } + + template + /*static*/ std::pair>, MBLayoutPtr> CompositeFunction::GetCNTKImplMatrixAndMBLayoutFromValueObject(Variable var, const ValuePtr& value) + { + if (var.GetDataType() != value->GetDataType()) + LogicError("The Variable's DataType %s does not match the corresponding Value's DataType %s", 
DataTypeName(var.GetDataType()), DataTypeName(value->GetDataType())); + + if (AsDataType() != value->GetDataType()) + LogicError("The specified ElementType %s does not match the DataType %s", typeid(ElementType).name(), DataTypeName(value->GetDataType())); + + // TODO: Is supplying dense data for an Input variable tagged as sparse, a fatal error? + if (IsSparseInput(var) && !value->IsSparse()) + InvalidArgument("Dense input data supplied for a sparse input Variable"); + + if (IsSparseInput(var) && (value->GetStorageFormat() != StorageFormat::SparseCSC)) + InvalidArgument("Sparse Input data must be in SparseCSC format"); + + auto varShape = var.Shape(); + auto valueShape = value->Shape(); + if (valueShape.Rank() < varShape.Rank()) + InvalidArgument("Value's rank should be >= the Variable's rank"); + + size_t maxAddionalValueAxes = std::max(2, var.DynamicAxes().size()); + if (valueShape.Rank() > (varShape.Rank() + maxAddionalValueAxes)) + InvalidArgument("Value rank should be larger than the Variable%S rank at most by number of dynamic axes", ParanthesizedName(var.Name()).c_str()); + + if (valueShape.SubShape(0, varShape.Rank()) != varShape) + { + InvalidArgument("The %s dimensions of the Value shape %S do not match the shape of the variable %S that it corresponds to!", + Internal::IsReversingTensorShapesInErrorMessagesEnabled() ? "trailing" : "leading", + AsStringForErrorReporting(valueShape).c_str(), + AsStringForErrorReporting(varShape).c_str()); + } + + if (var.DynamicAxes().empty()) + return{ value->Data()->GetMatrix(), nullptr }; + + if (var.DynamicAxes().size() > 2) + LogicError("More than 2 dynamic axis for a variable is currently unsupported"); + + auto mask = value->Mask(); + if ((mask != nullptr) && ((varShape.Rank() + mask->Shape().Rank()) != valueShape.Rank())) + InvalidArgument("Invalid Value object; the sum of the rank of the mask and data does not equal the Variable's rank + number of dynamic axes"); + + auto getNumTimeStepsAndSequencesFunc = [](const NDShape& maskShape) { + size_t maxNumTimeSteps = 1; + size_t numSequences = 1; + if (maskShape.Rank() > 0) + maxNumTimeSteps = maskShape[0]; + + if (maskShape.Rank() > 1) + numSequences = maskShape[1]; + + return std::pair(maxNumTimeSteps, numSequences); + }; + + size_t maxNumTimeSteps, numSequences; + std::tie(maxNumTimeSteps, numSequences) = getNumTimeStepsAndSequencesFunc(valueShape.SubShape(varShape.Rank())); + + auto getSequenceStartsAndLengthsFunc = [&getNumTimeStepsAndSequencesFunc](const NDMaskPtr& mask, std::vector& sequenceBeginIndices, std::vector& sequenceLengths) { + auto cpuMask = mask; + if (mask->Device() != DeviceDescriptor::CPUDevice()) + cpuMask = mask->DeepClone(DeviceDescriptor::CPUDevice()); + + const MaskKind* maskBuffer = cpuMask->DataBuffer(); + size_t maxNumTimeSteps, numSequences; + std::tie(maxNumTimeSteps, numSequences) = getNumTimeStepsAndSequencesFunc(mask->Shape()); + + for (size_t i = 0; i < numSequences; ++i) + { + MaskKind firstMaskEntry = maskBuffer[i * maxNumTimeSteps]; + if (firstMaskEntry == MaskKind::SequenceBegin) + sequenceBeginIndices[i] = 0; + else if (firstMaskEntry == MaskKind::Valid) + sequenceBeginIndices[i] = Microsoft::MSR::CNTK::SentinelValueIndicatingUnspecifedSequenceBeginIdx; + else + LogicError("The first entry of a mask should be Valid or SequenceBegin"); + + size_t currentSequenceLength = 1; + bool currentSequenceEndAlreadyFound = false; + for (size_t j = 1; j < maxNumTimeSteps; ++j) + { + if (maskBuffer[(i * maxNumTimeSteps) + j] == MaskKind::Invalid) + 
currentSequenceEndAlreadyFound = true; + else + { + if (currentSequenceEndAlreadyFound) + InvalidArgument("Invalid Value object; only trailing steps of a sequence can be masked"); + + currentSequenceLength++; + } + } + + sequenceLengths[i] = currentSequenceLength; + } + }; + + if ((numSequences == 1) || (maxNumTimeSteps == 1)) + { + // The data need not be shuffled + std::shared_ptr> matrixData = value->Data()->GetMatrix(varShape.Rank()); + auto layout = std::make_shared(); + if (!mask) + { + if (maxNumTimeSteps == 1) + layout->InitAsFrameMode(numSequences); + else + { + layout->Init(numSequences, maxNumTimeSteps); + layout->AddSequence(0, 0, 0, maxNumTimeSteps); + } + } + else + { + layout->Init(numSequences, maxNumTimeSteps); + + std::vector sequenceBeginIndices(numSequences, 0); + std::vector sequenceLengths(numSequences, maxNumTimeSteps); + getSequenceStartsAndLengthsFunc(mask, sequenceBeginIndices, sequenceLengths); + + for (size_t i = 0; i < numSequences; ++i) + layout->AddSequence(i, i, sequenceBeginIndices[i], sequenceLengths[i]); + } + + return{ matrixData , layout}; + } + else + { + std::vector sequenceBeginIndices(numSequences, 0); + std::vector sequenceLengths(numSequences, maxNumTimeSteps); + if (mask != nullptr) + getSequenceStartsAndLengthsFunc(mask, sequenceBeginIndices, sequenceLengths); + + bool hasTruncatedSequences = std::find_if(sequenceBeginIndices.begin(), sequenceBeginIndices.end(), [](const int& val) { return (val < 0); }) != sequenceBeginIndices.end(); + + auto layout = std::make_shared(); + std::vector> placement; + if (!hasTruncatedSequences) + { + std::vector sequences; + for (size_t i = 0; i < numSequences; ++i) + sequences.push_back({ i, SIZE_MAX, sequenceBeginIndices[i], sequenceLengths[i] }); + + std::vector rowAllocations; + layout->InitAsPackedSequences(sequences, placement, rowAllocations); + } + else + { + layout->Init(numSequences, maxNumTimeSteps); + + // We cannot pack as some of the sequences are truncated and thus all sequences have to be + // kept in their original parallel streams + placement.resize(numSequences); + for (size_t i = 0; i < numSequences; ++i) + { + layout->AddSequence(i, i, sequenceBeginIndices[i], sequenceLengths[i]); + + // Add the gap if there is one + if (sequenceLengths[i] < maxNumTimeSteps) + layout->AddSequence(GAP_SEQUENCE_ID, i, sequenceLengths[i], maxNumTimeSteps); + + placement[i] = std::make_pair(i, 0); + } + } + + if (maxNumTimeSteps != layout->GetNumTimeSteps()) + LogicError("The number of time steps in the packed MBLayout does not match the longest sequence's length in the Value object"); + + if (numSequences != layout->GetNumSequences()) + LogicError("The number of sequences in the packed MBLayout does not match the sequence count in the Value object"); + + // The data needs to be rearranged since CNTK requires sequences to be interleaved across timesteps + // Now generate the gather indices + auto matrixData = std::make_shared>(varShape.TotalSize(), + layout->GetNumCols(), + AsCNTKImplDeviceId(value->Device()), + value->IsSparse() ? 
MatrixType::SPARSE : MatrixType::DENSE, + AsCNTKImplMatrixFormat(value->GetStorageFormat())); + + std::vector sequencesShorterThanLongestSequence; + for (size_t i = 0; i < numSequences; ++i) + if (sequenceLengths[i] != maxNumTimeSteps) + sequencesShorterThanLongestSequence.push_back(i); + + // Set the source location for all gaps to be the last step of the first sequence that is shorter than the longest sequence in the batch + size_t sourceColIdxForInvalidColumns = sequencesShorterThanLongestSequence.empty() ? 0 : (((sequencesShorterThanLongestSequence[0] + 1) * maxNumTimeSteps) - 1); + std::vector gatherIndicesVector(layout->GetNumCols(), (ElementType)sourceColIdxForInvalidColumns); + for (size_t i = 0; i < numSequences; ++i) + { + size_t targetParallelStreamIdx = placement[i].first; + size_t targetStartIdxInParallelStream = placement[i].second; + for (size_t j = 0; j < sequenceLengths[i]; ++j) + gatherIndicesVector[((targetStartIdxInParallelStream + j) * layout->GetNumParallelSequences()) + targetParallelStreamIdx] = (ElementType)((i * maxNumTimeSteps) + j); + } + + auto gatherIdxMatrix = std::make_shared>(1, layout->GetNumCols(), gatherIndicesVector.data(), AsCNTKImplDeviceId(value->Device())); + matrixData->DoGatherColumnsOf(0, *gatherIdxMatrix, *(value->Data()->GetMatrix(varShape.Rank())), 1); + return{ matrixData, layout }; + } + } + + template + /*static*/ ValuePtr CompositeFunction::GetValueObjectFromCNTKImplMatrixAndMBLayout(const NDShape& sampleShape, const Matrix& matrix, const MBLayoutPtr& layout, bool readOnly /*= true*/) + { + NDShape valueDataShape = sampleShape; + + size_t maxNumTimeSteps = 1; + size_t numSequences = 1; + if (layout != nullptr) + { + maxNumTimeSteps = layout->GetNumTimeSteps(); + numSequences = layout->GetNumSequences(); + valueDataShape = valueDataShape.AppendShape({ maxNumTimeSteps, numSequences }); + } + + auto createMaskFunc = [](const MBLayoutPtr& layout, const DeviceDescriptor& device, std::vector& sequencesShorterThanLongestSequence) { + std::vector sequenceBeginFlags; + std::vector sequenceLengths; + sequencesShorterThanLongestSequence.clear(); + + size_t maxNumTimeSteps = layout->GetNumTimeSteps(); + size_t numSequences = layout->GetNumSequences(); + auto& layoutSequences = layout->GetAllSequences(); + + size_t sequenceIdx = 0; + bool allSequencesStartInThisMB = true; + bool allSequencesSameLength = true; + for (auto sequenceInfo : layoutSequences) + { + if (sequenceInfo.seqId != GAP_SEQUENCE_ID) + { + auto currentSequenceBeginIdx = std::max(0, sequenceInfo.tBegin); + auto currentSequenceEndIdx = std::min(maxNumTimeSteps, sequenceInfo.tEnd); + auto currentSequenceLength = (currentSequenceEndIdx - currentSequenceBeginIdx); + auto isCurrentSequenceBeginningInsideThisMB = sequenceInfo.tBegin >= 0; + + allSequencesStartInThisMB = allSequencesStartInThisMB && isCurrentSequenceBeginningInsideThisMB; + allSequencesSameLength = allSequencesSameLength && (currentSequenceLength == maxNumTimeSteps); + + sequenceBeginFlags.push_back(isCurrentSequenceBeginningInsideThisMB); + sequenceLengths.push_back(currentSequenceLength); + + if (currentSequenceLength != maxNumTimeSteps) + sequencesShorterThanLongestSequence.push_back(sequenceIdx); + + sequenceIdx++; + } + } + + if (!allSequencesStartInThisMB && (numSequences != layout->GetNumParallelSequences())) + LogicError("Cannot create an unpacked Value object from packed data where one or more sequences are truncated"); + + bool maskNeeded = !allSequencesSameLength || !allSequencesStartInThisMB; + + NDMaskPtr mask; + 
if (maskNeeded) + { + mask = MakeSharedObject(NDShape({ maxNumTimeSteps, numSequences }), DeviceDescriptor::CPUDevice()); + for (size_t i = 0; i < numSequences; ++i) + if (sequenceBeginFlags[i]) + mask->MarkSequenceBegin({0, i}); + + for (auto shortSequenceIdx : sequencesShorterThanLongestSequence) + mask->InvalidateSection({ sequenceLengths[shortSequenceIdx], shortSequenceIdx }, { NDShape::InferredDimension, 1 }); + } + + return mask; + }; + + // No data shuffling needed if no layout or the layout has just one time-step or just one sequence + std::vector sequencesShorterThanLongestSequence; + if ((maxNumTimeSteps == 1) || (numSequences == 1)) + { + // Just create a view over the existing matrix itself + auto tensorView = new TensorView(std::make_shared>(matrix.AsReference()), AsTensorViewShape(valueDataShape)); + auto data = MakeSharedObject(AsDataType(), AsDeviceDescriptor(matrix.GetDeviceId()), AsStorageFormat(matrix.GetFormat()), valueDataShape, readOnly, tensorView); + if (layout == nullptr) + return MakeSharedObject(data); + else + { + auto mask = createMaskFunc(layout, AsDeviceDescriptor(matrix.GetDeviceId()), sequencesShorterThanLongestSequence); + return MakeSharedObject(data, mask); + } + } + + if (layout->GetNumCols() != matrix.GetNumCols()) + LogicError("Bad MBLayout: The number of columns in the MBLayout does not match the number of columns in the data matrix!"); + + // Reshuffle to data to unpack and uninterleave the CNTK form packed data + // Now generate the scatter indices + auto shuffledMatrixData = std::make_shared>(matrix.GetNumRows(), maxNumTimeSteps * numSequences, matrix.GetDeviceId(), matrix.GetMatrixType(), matrix.GetFormat()); + auto mask = createMaskFunc(layout, AsDeviceDescriptor(matrix.GetDeviceId()), sequencesShorterThanLongestSequence); + + // Set the target location of all gaps to be the last step of the first sequence that is shorter than the longest sequence in the batch + size_t targetColIdxForInvalidColumns = sequencesShorterThanLongestSequence.empty() ? 
0 : (((sequencesShorterThanLongestSequence[0] + 1) * maxNumTimeSteps) - 1); + std::vector scatterIndicesVector(layout->GetNumCols(), (ElementType)targetColIdxForInvalidColumns); + + size_t i = 0; + auto& layoutSequences = layout->GetAllSequences(); + for (auto sequenceInfo : layoutSequences) + { + if (sequenceInfo.seqId != GAP_SEQUENCE_ID) + { + size_t targetParallelStreamIdx = sequenceInfo.s; + auto currentSequenceBeginIdx = std::max(0, sequenceInfo.tBegin); + auto currentSequenceEndIdx = std::min(maxNumTimeSteps, sequenceInfo.tEnd); + size_t currentSequenceLength = (currentSequenceEndIdx - currentSequenceBeginIdx); + + for (size_t j = 0; j < currentSequenceLength; ++j) + scatterIndicesVector[((currentSequenceBeginIdx + j) * layout->GetNumParallelSequences()) + targetParallelStreamIdx] = (ElementType)((i * maxNumTimeSteps) + j); + + i++; + } + } + + auto scatterIdxMatrix = std::make_shared>(1, layout->GetNumCols(), scatterIndicesVector.data(), matrix.GetDeviceId()); + shuffledMatrixData->DoScatterColumnsOf(0, *scatterIdxMatrix, matrix, 1); + + auto tensorView = new TensorView(shuffledMatrixData, AsTensorViewShape(valueDataShape)); + auto data = MakeSharedObject(AsDataType(), AsDeviceDescriptor(matrix.GetDeviceId()), AsStorageFormat(shuffledMatrixData->GetFormat()), valueDataShape, readOnly, tensorView); + return MakeSharedObject(data, mask); + } + + template + /*static*/ ValuePtr CompositeFunction::GetValueObjectFromCNTKImplMatrixAndMBLayout(Variable var, const Matrix& matrix, const MBLayoutPtr& layout, bool readOnly /*= true*/) + { + if (var.DynamicAxes().size() > 2) + LogicError("More than 2 dynamic axis for a variable is currently unsupported"); + + if (AsDataType() != var.GetDataType()) + LogicError("The specified ElementType %s does not match the DataType %s", typeid(ElementType).name(), DataTypeName(var.GetDataType())); + + if ((layout != nullptr) && (matrix.GetNumRows() != var.Shape().TotalSize())) + LogicError("Unexpected matrix layout: The number of rows in the matrix does not match the sample size of the Variable"); + + return GetValueObjectFromCNTKImplMatrixAndMBLayout(var.Shape(), matrix, layout, readOnly); + } + + template + /*static*/ void CompositeFunction::PopulateComputationNodeValue(const std::pair& variableValue, ComputationNodeBasePtr& computationNode) + { + std::pair>, MBLayoutPtr> CNTKMatrixAndMBLayout; + auto packedValue = dynamic_cast(variableValue.second.get()); + if (packedValue) + CNTKMatrixAndMBLayout = packedValue->PackedData(); + else + CNTKMatrixAndMBLayout = GetCNTKImplMatrixAndMBLayoutFromValueObject(variableValue.first, variableValue.second); + + MBLayoutPtr layout = CNTKMatrixAndMBLayout.second; + + auto& nodeData = computationNode->As>()->Value(); + + // Switch the node matrix to the right matrix type + nodeData.AssignValuesOf(*CNTKMatrixAndMBLayout.first); + computationNode->GetMBLayout()->CopyFrom(layout); + } + + void CompositeFunction::PopulateNetworkInputs(const std::unordered_map& arguments) + { + std::vector inputNodes; + for (auto argumentValuePair : arguments) + { + auto argument = argumentValuePair.first; + auto argumentComputationNode = m_variableToNodeMap[argument]; + assert(argumentComputationNode); + inputNodes.push_back(argumentComputationNode); + + ValuePtr argumentValue = arguments.at(argument); + + MBLayoutPtr layout; + switch (argumentValue->GetDataType()) + { + case DataType::Float: + PopulateComputationNodeValue({ argument, argumentValue }, argumentComputationNode); + break; + case DataType::Double: + PopulateComputationNodeValue({ 
argument, argumentValue }, argumentComputationNode); + break; + default: + LogicError("Unsupported DataType %s", DataTypeName(argumentValue->GetDataType())); + break; + } + } + + m_computationNetwork->BumpEvalTimeStamp(inputNodes); + } + + template + /*static*/ void CompositeFunction::PopulateComputationNodeGradient(const std::pair& variableGradient, Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode) + { + std::pair>, MBLayoutPtr> CNTKMatrixAndMBLayout; + auto packedValue = dynamic_cast(variableGradient.second.get()); + if (packedValue) + CNTKMatrixAndMBLayout = packedValue->PackedData(); + else + CNTKMatrixAndMBLayout = GetCNTKImplMatrixAndMBLayoutFromValueObject(variableGradient.first, variableGradient.second); + + MBLayoutPtr layout = CNTKMatrixAndMBLayout.second; + auto nodeLayout = computationNode->GetMBLayout(); + if (((layout == nullptr) != (nodeLayout == nullptr)) || ((layout != nullptr) && (*layout != *nodeLayout))) + InvalidArgument("The layout of the specified gradient Value is incompatible with the layout of the corresponding Variable computed during Forward call"); + computationNode->As>()->AssignGradient(*CNTKMatrixAndMBLayout.first); + } + + // Assign the supplied gradients corresponding to the root(s) of the network to be backpropagated through the graph + void CompositeFunction::PopulateNetworkGradients(const std::unordered_map& gradients) + { + auto functionOutputs = this->Outputs(); + for (auto gradientVarValuePair : gradients) + { + // Only gradients for roots of the function can be specified + if (std::find(functionOutputs.begin(), functionOutputs.end(), gradientVarValuePair.first) == functionOutputs.end()) + InvalidArgument("Gradients cannot be specified for a Variable that is not an Output of the Function"); + + auto outputComputationNode = m_variableToNodeMap[gradientVarValuePair.first]; + ValuePtr gradientValue = gradientVarValuePair.second; + + switch (gradientValue->GetDataType()) + { + case DataType::Float: + PopulateComputationNodeGradient(gradientVarValuePair, outputComputationNode); + break; + case DataType::Double: + PopulateComputationNodeGradient(gradientVarValuePair, outputComputationNode); + break; + default: + LogicError("Unsupported DataType %s", DataTypeName(gradientValue->GetDataType())); + break; + } + } + } + + static NDShape GetValueShape(const Variable& var, const ComputationNodeBasePtr& computationNodePtr) + { + size_t outputValueNumAxes = var.Shape().Rank(); + + // Add the batch and dynamic axes if needed + if (computationNodePtr->GetMBLayout() != nullptr) + outputValueNumAxes += 2; + + std::vector outputShapeDims(outputValueNumAxes); + for (size_t i = 0; i < var.Shape().Rank(); ++i) + outputShapeDims[i] = computationNodePtr->GetSampleLayout().GetDim(i); + + if (computationNodePtr->GetMBLayout() != nullptr) + { + outputShapeDims[var.Shape().Rank()] = computationNodePtr->GetMBLayout()->GetNumTimeSteps(); + outputShapeDims[var.Shape().Rank() + 1] = computationNodePtr->GetMBLayout()->GetNumSequences(); + } + + return NDShape(outputShapeDims); + } + + /*static*/ void CompositeFunction::GetNodeOutputOrGradient(Variable var, ValuePtr& varValue, Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode, bool getGradient) + { + auto valueShape = GetValueShape(var, computationNode); + if (varValue != nullptr) + { + // TODO: The shape of the specified output Value object must match the actual output shape + if ((varValue->Shape() != valueShape) && (AsTensorShape(varValue->Shape()) != AsTensorShape(valueShape))) + InvalidArgument("The 
shape %S of the specified Value object for %s does not match the actual shape %S", AsStringForErrorReporting(varValue->Shape()).c_str(), getGradient ? "gradient" : "output", AsStringForErrorReporting(valueShape).c_str()); + } + + ValuePtr nodeValue; + auto layout = computationNode->GetMBLayout(); + switch (var.GetDataType()) + { + case DataType::Float: + { + auto& matrix = getGradient ? computationNode->As>()->Gradient() : computationNode->As>()->Value(); + if (varValue == nullptr) + nodeValue = MakeSharedObject(var.Shape(), std::make_shared>(matrix.AsReference()), layout, /*readOnly =*/ false); + else + nodeValue = GetValueObjectFromCNTKImplMatrixAndMBLayout(var, matrix, layout); + break; + } + case DataType::Double: + { + auto& matrix = getGradient ? computationNode->As>()->Gradient() : computationNode->As>()->Value(); + if (varValue == nullptr) + nodeValue = MakeSharedObject(var.Shape(), std::make_shared>(matrix.AsReference()), layout, /*readOnly =*/ false); + else + nodeValue = GetValueObjectFromCNTKImplMatrixAndMBLayout(var, matrix, layout); + break; + } + default: + LogicError("Unsupported DataType %s", DataTypeName(var.GetDataType())); + break; + } + + if (varValue == nullptr) + varValue = nodeValue; + else + varValue->CopyFrom(*nodeValue); + } + + void CompositeFunction::GetNetworkOutputs(std::unordered_map& outputs) + { + // Now copy the Forward values of output nodes from the network to outputs' Value objects + for (auto outputVarValuePair : outputs) + GetNodeOutputOrGradient(outputVarValuePair.first, outputs[outputVarValuePair.first], m_variableToNodeMap[outputVarValuePair.first], false /*getGradient*/); + } + + void CompositeFunction::GetNetworkGradients(std::unordered_map& gradients) + { + auto networkInputs = this->Inputs(); + // Now copy the gradient values of input nodes of the network to gradients' Value objects + for (auto gradientVarValuePair : gradients) + { + // Only gradients corresponding to inputs of the network can be obtained + if (std::find(networkInputs.begin(), networkInputs.end(), gradientVarValuePair.first) == networkInputs.end()) + InvalidArgument("Backpropagated gradient values can only be obtained for inputs of a Function"); + + // Gradients can only be obtained for parameter variables or input variables that NeedsGradient + if (!gradientVarValuePair.first.NeedsGradient()) + InvalidArgument("Gradient value incorrectly requested for an Output or Constant Variable, or an Input Variable with NeedsGradient setting of false"); + + auto computationNodePtr = m_variableToNodeMap[gradientVarValuePair.first]; + + if (!computationNodePtr->NeedsGradient()) + LogicError("Backpropagated gradient value cannot be read from a ComputationNode that has NeedsGradient set to false"); + + GetNodeOutputOrGradient(gradientVarValuePair.first, gradients[gradientVarValuePair.first], computationNodePtr, true /*getGradient*/); + } + } + + const std::vector& CompositeFunction::GetArgumentDependencies(const Variable& output) + { + assert(output.IsOutput()); + + auto iter = m_perOutputVarArgumentDependencies.find(output); + if (iter != m_perOutputVarArgumentDependencies.end()) + return iter->second; + + auto wrappedComposite = CompositeFunction::Create(output.Owner()); + m_perOutputVarArgumentDependencies[output] = wrappedComposite->Arguments(); + + return m_perOutputVarArgumentDependencies[output]; + } + + /*virtual*/ BackPropStatePtr CompositeFunction::Forward(const std::unordered_map& arguments, + std::unordered_map& outputs, + const DeviceDescriptor& computeDevice, + const 
std::unordered_set& outputsToRetainBackwardStateFor) + { + // Validate arguments and outputs + if (outputs.empty()) + InvalidArgument("CompositeFunction::Forward: At least one output has to be specified!"); + + // Make sure that the DataType of the variables and corresponding values match + // TODO: We need a better way to determine the ElementType for the network + auto dataType = DataType::Unknown; + for (auto variableValuePair : arguments) + { + if (dataType == DataType::Unknown) + dataType = variableValuePair.first.GetDataType(); + else if (dataType != variableValuePair.first.GetDataType()) + LogicError("CompositeFunction::Forward: The DataType of all arguments of the Function must be the same"); + } + + if (dataType == DataType::Unknown) + { + for (auto variableValuePair : outputs) + { + if (dataType == DataType::Unknown) + dataType = variableValuePair.first.GetDataType(); + } + } + + if (dataType == DataType::Float) + GetComputationNetwork(computeDevice, outputsToRetainBackwardStateFor, true); + else if (dataType == DataType::Double) + GetComputationNetwork(computeDevice, outputsToRetainBackwardStateFor, true); + else + InvalidArgument("Unsupported DataType %s", DataTypeName(dataType)); + + std::unordered_set functionOutputs(this->Outputs().begin(), this->Outputs().end()); + std::vector outputsToEvaluate; + std::unordered_set requiredArguments; + for (auto outputVarValuePair : outputs) + { + // Ensure that only a subset of this function's outputs are being asked to be evaluated + if (functionOutputs.find(outputVarValuePair.first) == functionOutputs.end()) + InvalidArgument("Requested output is not an Output of the Function"); + + auto& requiredArgumentsForCurrentOutput = GetArgumentDependencies(outputVarValuePair.first); + requiredArguments.insert(requiredArgumentsForCurrentOutput.begin(), requiredArgumentsForCurrentOutput.end()); + + auto outputComputationNode = m_variableToNodeMap[outputVarValuePair.first]; + outputsToEvaluate.push_back(outputComputationNode); + } + + // TODO: Avoid copying the data when possible + + // We should have argument values supplied for all required argument dependencies for the requested outputs + for (auto requiredArgument : requiredArguments) + { + if (arguments.find(requiredArgument) == arguments.end()) + InvalidArgument("Function::Forward: Required argument's (%S) value that the requested output(s) depend on has not been provided", requiredArgument.Name().c_str()); + } + + // Feed data into the arguments of the network + PopulateNetworkInputs(arguments); + + // Dropout nodes have an implicit input in the form of the random mask that is applied to their explicit input + // This mask is regenerated every minibatch and hence dropout nodes with a non-zero dropout rate must be marked outdated + // w.r.t. 
inputs to force evaluation in each minibatch + list dropoutNodes = m_computationNetwork->GetNodesWithType(OperationNameOf(DropoutNode)); + for (auto& nodeIter : dropoutNodes) + nodeIter->SetEvalTimeStampOutdatedWrtAll(); + + // Bump the timestamp of the parameter nodes whose values have changed + for (auto& paramTimeStampRecord : m_lastRecordedParameterValueTimeStamps) + { + auto parameter = paramTimeStampRecord.first; + auto prevTimeStamp = paramTimeStampRecord.second; + auto newTimeStamp = parameter.CurrentValueTimeStamp(); + if (newTimeStamp > prevTimeStamp) + { + paramTimeStampRecord.second = newTimeStamp; + m_variableToNodeMap[parameter]->BumpEvalTimeStamp(); + } + } + + // The 'outputsToRetainBackwardStateFor' nodes also need to be evaluated if not already specified in 'outputs' + for (auto rootVarForBackprop : outputsToRetainBackwardStateFor) + { + if (functionOutputs.find(rootVarForBackprop) == functionOutputs.end()) + InvalidArgument("Requested output to retain backward state for is not an Output of the Function"); + + if (outputs.find(rootVarForBackprop) == outputs.end()) + outputsToEvaluate.push_back(m_variableToNodeMap[rootVarForBackprop]); + } + + // TODO: Verify that values were supplied for all inputs that requested outputs depend on + + ScopedNetworkOperationMode modeGuard(m_computationNetwork, outputsToRetainBackwardStateFor.empty() ? NetworkOperationMode::inferring : NetworkOperationMode::training); + + m_computationNetwork->ForwardProp(outputsToEvaluate); + + GetNetworkOutputs(outputs); + + // TODO: How to deal with the specified 'computeDevice' + Variable evalTimeStampVariable; + if (arguments.empty()) + evalTimeStampVariable = Inputs()[0]; + else + evalTimeStampVariable = arguments.begin()->first; + + return (outputsToRetainBackwardStateFor.size() > 0) ? MakeSharedObject(this->shared_from_this(), computeDevice, std::make_pair(evalTimeStampVariable, m_variableToNodeMap[evalTimeStampVariable]->GetEvalTimeStamp())) : nullptr; + } + + /*virtual*/ void CompositeFunction::Backward(const BackPropStatePtr& state, + const std::unordered_map& rootGradientValues, + std::unordered_map& backPropagatedGradientValuesForInputs) + { + auto backpropState = dynamic_cast(state.get()); + if (backpropState == nullptr) + InvalidArgument("Invalid backprop state specified"); + + // TODO: Support multiple concurrent backprop states + if (backpropState->EvalTimeStamp().second != m_variableToNodeMap[backpropState->EvalTimeStamp().first]->GetEvalTimeStamp()) + LogicError("The specified backprop state cannot be used for backpropagation as the Function's internal state was modified by subsequent Forward calls to the function." 
+ "This is not a user error but a shortcoming of the current implementation where multiple independent backprop states are not simultaneously supported"); + + if (rootGradientValues.size() > 1) + LogicError("Currently gradient backprop from only one of the Function Outputs is supported"); + + // TODO: Avoid copying the data when possible + + // Zero all gradients of nodes below the root nodes + for (auto rootGradientVarValuePair : rootGradientValues) + m_computationNetwork->ZeroInputGradients(m_variableToNodeMap[rootGradientVarValuePair.first]); + + // Feed data into the arguments of the network + PopulateNetworkGradients(rootGradientValues); + + // Backpropagate through the network + ScopedNetworkOperationMode modeGuard(m_computationNetwork, NetworkOperationMode::training); + + auto rootComputationNodePtr = m_variableToNodeMap[rootGradientValues.begin()->first]; + m_computationNetwork->GetNestedNetwork(rootComputationNodePtr)->Backprop(FrameRange(nullptr), true, true); + + GetNetworkGradients(backPropagatedGradientValuesForInputs); + + // TODO: How to deal with the specified 'computeDevice' + } +} diff --git a/Source/CNTKv2LibraryDll/CompositeFunction.h b/Source/CNTKv2LibraryDll/CompositeFunction.h new file mode 100644 index 000000000..768b0786a --- /dev/null +++ b/Source/CNTKv2LibraryDll/CompositeFunction.h @@ -0,0 +1,267 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE.md file in the project root for full license information. +// + +#pragma once + +#include "stdafx.h" +#include "CNTKLibrary.h" +#include "PrimitiveFunction.h" +#include "ComputationNetwork.h" +#include "BackCompat.h" + +namespace CNTK +{ + class CNTKBackPropState final : public BackPropState + { + public: + CNTKBackPropState(const FunctionPtr& function, const DeviceDescriptor& computeDevice, const std::pair& evalTimeStamp) + : BackPropState(function, computeDevice), m_evalTimeStamp(evalTimeStamp) + {} + + std::pair EvalTimeStamp() const + { + return m_evalTimeStamp; + } + + private: + std::pair m_evalTimeStamp; + }; + typedef std::shared_ptr CNTKBackPropStatePtr; + + class CompositeFunction; + typedef std::shared_ptr CompositeFunctionPtr; + + class CompositeFunction final : public Function + { + friend class Function; + friend class Trainer; + friend class CompositeMinibatchSource; + friend class PackedValue; + + template + friend inline std::shared_ptr MakeSharedObject(CtorArgTypes&& ...ctorArgs); + + friend void Internal::SaveAsLegacyModel(const FunctionPtr& rootFunction, const std::wstring& modelFile); + + friend void ComputeInputPerDimMeansAndInvStdDevs(const MinibatchSourcePtr& minibatchSource, + std::unordered_map>& computedMeanAndInvStdDevs, + const DeviceDescriptor& device /*= DeviceDescriptor::CPUDevice()*/); + + static std::atomic s_nextAutoGeneratedDynamicAxis; + + static const std::wstring CompositeFunctionOpName; + + public: + static const std::wstring InternalDefaultDynamicAxisName; + static const std::wstring InternalNoSequenceAxisName; + + static Axis NextAutoGeneratedDynamicAxis() + { + static const std::wstring s_autoGeneratedDynamicAxisNamePrefix = L"autoGeneratedDynamicAxis_"; + return Axis(s_autoGeneratedDynamicAxisNamePrefix + std::to_wstring(s_nextAutoGeneratedDynamicAxis++)); + } + + public: + static CompositeFunctionPtr Create(const FunctionPtr& rootFunction, const std::wstring& name = L"", const std::wstring& uid = L"") + { + std::unordered_set visitedFunctions; + + // Call Collect to get the set of all functions in the graph + 
Collect(rootFunction, visitedFunctions); + + return MakeSharedObject(rootFunction, std::move(visitedFunctions), name, uid); + } + + virtual BackPropStatePtr Forward(const std::unordered_map& arguments, + std::unordered_map& outputs, + const DeviceDescriptor& computeDevice, + const std::unordered_set& outputsToRetainBackwardStateFor) override; + + virtual void Backward(const BackPropStatePtr& state, + const std::unordered_map& rootGradientValues, + std::unordered_map& backPropagatedGradientValuesForInputs) override; + + virtual Dictionary Serialize() const override; + + virtual size_t CurrentVersion() const override { return s_serializationVersion; } + + static FunctionPtr Deserialize(const Dictionary& dictionary, const CNTK::DeviceDescriptor& device); + + virtual const std::wstring& OpName() override + { + return CompositeFunctionOpName; + } + + template + static void Traverse(const FunctionPtr& rootFunction, const FunctionType& functor) + { + std::unordered_set visitedFunctions; + Traverse(rootFunction, visitedFunctions, functor); + } + + // Recursively traverses the Function graph underlying the 'rootFunction' invoking the provided functor for all visited nodes in the graph. + template + static void Traverse(const FunctionPtr& rootFunction, std::unordered_set& visitedFunctions, const FunctionType& functor) + { + visitedFunctions.insert(rootFunction); + functor(rootFunction); + + std::vector rootFunctionInputs = rootFunction->Inputs(); + for (const auto& rootInput : rootFunctionInputs) + { + if (rootInput.IsOutput() && visitedFunctions.find(rootInput.Owner()) == visitedFunctions.end()) + { + const auto& function = rootInput.Owner(); + Traverse(function, visitedFunctions, functor); + } + } + } + + private: + virtual void ReplacePlaceholdersInPlace(const std::unordered_map& placeholderReplacements, + std::unordered_set& visitedFunctions, + std::unordered_set& replacedPlaceholders) override; + + CompositeFunction(const FunctionPtr& rootFunction, std::unordered_set&& allPrimitiveFunctions, const std::wstring& name, const std::wstring& uid = Internal::GenerateUid(L"CompositeFunction")) + : Function({}, rootFunction->Outputs(), Dictionary(), rootFunction, name, uid), + m_allPrimitiveFunctions(std::move(allPrimitiveFunctions)), m_networkMatricesAllocated(false) + {} + + std::vector DetermineInputs() const + { + const auto& root = RootFunction(); + std::unordered_set visitedFunctions; + return DetermineInputs(root, visitedFunctions); + } + + // Recursively traverses the Function graph and populates the provided set of functions. 
+ static void Collect(const FunctionPtr& rootFunction, std::unordered_set& functions) + { + // Call Traverse to get the set of all functions in the graph + Traverse(rootFunction, functions, [](const FunctionPtr& f){}); + } + + // Recursively traverses the Function graph underlying the 'rootFunction' to determine all the leaves (aka inputs) of the graph + static std::vector DetermineInputs(const FunctionPtr& rootFunction, std::unordered_set& visitedFunctions) + { + vector functions; + std::vector inputs; + std::unordered_set uniqueInputs; + Traverse(rootFunction, visitedFunctions, [&inputs, &uniqueInputs](const FunctionPtr& f){ + std::vector functionInputs = f->Inputs(); + for (auto input : functionInputs) + { + if (!input.IsOutput() && uniqueInputs.find(input) == uniqueInputs.end()) + { + inputs.push_back(input); + uniqueInputs.insert(input); + } + } + }); + + return inputs; + } + + template + Microsoft::MSR::CNTK::ComputationNetworkPtr GetComputationNetwork(const DeviceDescriptor& device, const std::unordered_set& backpropRoots, bool allocateNetworkMatrices); + + template + static Microsoft::MSR::CNTK::ComputationNodeBasePtr CreateComputationNode(const Variable& variable, + PrimitiveFunction* primitiveFunction, + const std::vector>>& inputNodes, + Microsoft::MSR::CNTK::ComputationNetworkPtr& network, + std::unordered_map& variableToNodeMap); + + template + static Microsoft::MSR::CNTK::ComputationNodeBasePtr GetOutputVariableNode(const Variable& variable, + Microsoft::MSR::CNTK::ComputationNetworkPtr& network, + Microsoft::MSR::CNTK::ComputationNetworkBuilder& builder, + std::unordered_map& variableToNodeMap, + std::unordered_map& isVariableRootMap); + + template + static Microsoft::MSR::CNTK::ComputationNodeBasePtr GetNode(const Variable& variable, Microsoft::MSR::CNTK::ComputationNetworkPtr& network, + Microsoft::MSR::CNTK::ComputationNetworkBuilder& builder, + std::unordered_map& variableToNodeMap, + std::unordered_map& isVariableRootMap); + + template + static void PopulateComputationNodeValue(const std::pair& variableValue, Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode); + void PopulateNetworkInputs(const std::unordered_map& arguments); + + template + static void PopulateComputationNodeGradient(const std::pair& variableGradient, Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode); + void PopulateNetworkGradients(const std::unordered_map& gradients); + + static void GetNodeOutputOrGradient(Variable var, ValuePtr& varValue, Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode, bool getGradient); + void GetNetworkOutputs(std::unordered_map& outputs); + void GetNetworkGradients(std::unordered_map& gradients); + + template + static std::pair>, Microsoft::MSR::CNTK::MBLayoutPtr> GetCNTKImplMatrixAndMBLayoutFromValueObject(Variable var, const ValuePtr& value); + + template + static ValuePtr GetValueObjectFromCNTKImplMatrixAndMBLayout(const NDShape& sampleShape, const Microsoft::MSR::CNTK::Matrix& matrix, const Microsoft::MSR::CNTK::MBLayoutPtr& layout, bool readOnly = true); + template + static ValuePtr GetValueObjectFromCNTKImplMatrixAndMBLayout(Variable var, const Microsoft::MSR::CNTK::Matrix& matrix, const Microsoft::MSR::CNTK::MBLayoutPtr& layout, bool readOnly = true); + + const std::vector& GetArgumentDependencies(const Variable& output); + + private: + + // Set of all primitive functions in the graph underlying 'this' Function. 
Also keeps the primitive Function objects alive + // by holding strong references to them + std::unordered_set m_allPrimitiveFunctions; + + // A map from Variable objects to ComputationNode objects in the ComputationNetwork instance that implements 'this' Composite Function + std::unordered_map m_variableToNodeMap; + + // A map that tells whether a Variable in the graph underlying 'this' Function is a root of the graph + std::unordered_map m_isVariableRootMap; + + Microsoft::MSR::CNTK::ComputationNetworkPtr m_computationNetwork; + + // The backpropRoots specified in the most recent 'Forward' call on 'this' Function. + // This indicates for which of its roots 'this' Function has retained the required intermediate + // states from the previous Forward call, to be able to backpropagate gradients backwards from them in + // the next 'Backward' call. + std::unordered_set m_currentBackpropRoots; + + std::unordered_map> m_perOutputVarArgumentDependencies; + + bool m_networkMatricesAllocated; + + std::unordered_map m_lastRecordedParameterValueTimeStamps; + + static const size_t s_serializationVersion = 1; + }; + + inline std::vector DynamicAxesFromInternalDynamicAxisName(const std::wstring& internalDynamicAxisName) + { + std::vector inputVarDynamicAxes; + if (internalDynamicAxisName.substr(0, CNTK::CompositeFunction::InternalDefaultDynamicAxisName.length()) == CNTK::CompositeFunction::InternalDefaultDynamicAxisName) + inputVarDynamicAxes = { CNTK::Axis::DefaultDynamicAxis(), CNTK::Axis::DefaultBatchAxis() }; + else if (internalDynamicAxisName.substr(0, CNTK::CompositeFunction::InternalNoSequenceAxisName.length()) == CNTK::CompositeFunction::InternalNoSequenceAxisName) + inputVarDynamicAxes = { CNTK::Axis::DefaultBatchAxis() }; + else + inputVarDynamicAxes = { CNTK::Axis(internalDynamicAxisName), CNTK::Axis::DefaultBatchAxis() }; + + return inputVarDynamicAxes; + } + + // Construct the dynamic axis name to be used internally for the CNTK InputNodes + inline std::wstring InternalDynamicAxisNameFromDynamicAxes(const std::vector& dynamicAxes) + { + if (dynamicAxes.empty()) + LogicError("Empty dynamic axes set"); + + if (dynamicAxes == std::vector({ Axis::DefaultBatchAxis() })) + return CompositeFunction::InternalNoSequenceAxisName; + else if (dynamicAxes == std::vector({ Axis::DefaultDynamicAxis(), Axis::DefaultBatchAxis() })) + return CompositeFunction::InternalDefaultDynamicAxisName; + else + return dynamicAxes[0].Name(); + } +} diff --git a/Source/CNTKv2LibraryDll/ComputeInputStatistics.cpp b/Source/CNTKv2LibraryDll/ComputeInputStatistics.cpp index da3acd652..2f73a0f25 100644 --- a/Source/CNTKv2LibraryDll/ComputeInputStatistics.cpp +++ b/Source/CNTKv2LibraryDll/ComputeInputStatistics.cpp @@ -6,7 +6,7 @@ #include "stdafx.h" #include "CNTKLibrary.h" #include "Utils.h" -#include "Function.h" +#include "CompositeFunction.h" #include #include "ComputationNetworkBuilder.h" diff --git a/Source/CNTKv2LibraryDll/Function.cpp b/Source/CNTKv2LibraryDll/Function.cpp index 093fc3215..796b67383 100644 --- a/Source/CNTKv2LibraryDll/Function.cpp +++ b/Source/CNTKv2LibraryDll/Function.cpp @@ -5,20 +5,8 @@ #include "stdafx.h" #include "CNTKLibrary.h" -#include "Function.h" -#include "ComputationNetworkBuilder.h" -#include "Utils.h" -#include "ComputationNode.h" -#include "ReshapingNodes.h" -#include "EvaluationNodes.h" -#include "TrainingNodes.h" -#include "LinearAlgebraNodes.h" -#include "InputAndParamNodes.h" -#include "NonlinearityNodes.h" -#include "RecurrentNodes.h" -#include "Serialization.h" -#include "Value.h" 
-#include "RNNNodes.h" +#include "PrimitiveFunction.h" +#include "CompositeFunction.h" using namespace Microsoft::MSR::CNTK; @@ -543,2127 +531,6 @@ namespace CNTK }); } - // Names for the reduction operations as used by the CNTK ReduceElementsNode - /*static*/ const std::wstring PrimitiveFunction::InternalSumReductionOpName = L"Sum"; - /*static*/ const std::wstring PrimitiveFunction::InternalLogSumReductionOpName = L"LogSum"; - /*static*/ const std::wstring PrimitiveFunction::InternalMeanReductionOpName = L"Mean"; - /*static*/ const std::wstring PrimitiveFunction::InternalMaxReductionOpName = L"Max"; - /*static*/ const std::wstring PrimitiveFunction::InternalMinReductionOpName = L"Min"; - /*static*/ const std::wstring PrimitiveFunction::InternalAllReductionOpName = L"All"; - /*static*/ const std::wstring PrimitiveFunction::InternalAnyReductionOpName = L"Any"; - - // Names of the various attributes of CNTK primitive Functions - /*static*/ const std::wstring PrimitiveFunction::AttributeNameAxis = L"axis"; - /*static*/ const std::wstring PrimitiveFunction::AttributeNameAxis1 = L"axis1"; - /*static*/ const std::wstring PrimitiveFunction::AttributeNameAxis2 = L"axis2"; - /*static*/ const std::wstring PrimitiveFunction::AttributeNameAllowDuplicates = L"allowDuplicates"; - /*static*/ const std::wstring PrimitiveFunction::AttributeNameNumSamples = L"numSamples"; - /*static*/ const std::wstring PrimitiveFunction::AttributeNameDropoutRate = L"dropoutRate"; - /*static*/ const std::wstring PrimitiveFunction::AttributeNameNewShape = L"newShape"; - /*static*/ const std::wstring PrimitiveFunction::AttributeNameOutputRank = L"outputRank"; - /*static*/ const std::wstring PrimitiveFunction::AttributeNameInferInputRankToMap = L"inferInputRankToMap"; - /*static*/ const std::wstring PrimitiveFunction::AttributeNameOffset = L"offset"; - /*static*/ const std::wstring PrimitiveFunction::AttributeNameStrides = L"strides"; - /*static*/ const std::wstring PrimitiveFunction::AttributeNameSharing = L"sharing"; - /*static*/ const std::wstring PrimitiveFunction::AttributeNameAutoPadding = L"autoPadding"; - /*static*/ const std::wstring PrimitiveFunction::AttributeNameLowerPad = L"lowerPad"; - /*static*/ const std::wstring PrimitiveFunction::AttributeNameUpperPad = L"upperPad"; - /*static*/ const std::wstring PrimitiveFunction::AttributeNameTranspose = L"transpose"; - /*static*/ const std::wstring PrimitiveFunction::AttributeNameMaxTempMemSizeInSamples = L"maxTempMemSizeInSamples"; - /*static*/ const std::wstring PrimitiveFunction::AttributeNameROIOutputShape = L"roiOutputShape"; - /*static*/ const std::wstring PrimitiveFunction::AttributeNamePoolingType = L"poolingType"; - /*static*/ const std::wstring PrimitiveFunction::AttributeNamePoolingWindowShape = L"poolingWindowShape"; - /*static*/ const std::wstring PrimitiveFunction::AttributeNameSpatial = L"spatial"; - /*static*/ const std::wstring PrimitiveFunction::AttributeNameNormalizationTimeConstant = L"normalizationTimeConstant"; - /*static*/ const std::wstring PrimitiveFunction::AttributeNameBlendTimeConstant = L"blendTimeConstant"; - /*static*/ const std::wstring PrimitiveFunction::AttributeNameEpsilon = L"epsilon"; - /*static*/ const std::wstring PrimitiveFunction::AttributeNameUseCuDNNEngine = L"useCuDNNEngine"; - /*static*/ const std::wstring PrimitiveFunction::AttributeNameNewDynamicAxes = L"newDynamicAxes"; - /*static*/ const std::wstring PrimitiveFunction::AttributeNameNewSequenceAxisLengthScalingFactor = L"newSequenceAxisLengthScalingFactor"; - /*static*/ 
const std::wstring PrimitiveFunction::AttributeNameNewSequenceAxisLengthAdditiveFactor = L"newSequenceAxisLengthAdditiveFactor"; - /*static*/ const std::wstring PrimitiveFunction::AttributeNameBeginIndex = L"beginIndex"; - /*static*/ const std::wstring PrimitiveFunction::AttributeNameEndIndex = L"endIndex"; - /*static*/ const std::wstring PrimitiveFunction::AttributeNameReductionOpName = L"reductionOpName"; - /*static*/ const std::wstring PrimitiveFunction::AttributeNameBidirectional = L"bidirectional"; - /*static*/ const std::wstring PrimitiveFunction::AttributeNameNumLayers = L"numLayers"; - /*static*/ const std::wstring PrimitiveFunction::AttributeNameHiddenSize = L"hiddenSize"; - /*static*/ const std::wstring PrimitiveFunction::AttributeNameRecurrentOp = L"recurrentOp"; - - /*static*/ std::vector PrimitiveFunction::GetOutputVariables(PrimitiveOpType op, - std::vector& inputs, - Function* owner, - Dictionary& functionConfig, - bool inferDimensions, - const std::wstring& functionName) - { - if (op == PrimitiveOpType::Combine) - return inputs; - - // We use the first non-constant input operand's DataType as the output DataType - // In case there are no non-constant known DataTypes, we just pick the first known operand DataType - // Also, all the known DataTypes of operands should match except for constants where coercion is allowed - DataType firstKnownInputDataType = DataType::Unknown; - DataType outputDataType = DataType::Unknown; - size_t i = 0; - while (i < inputs.size()) - { - auto input = inputs[i++]; - auto inputDataType = input.GetDataType(); - if (inputDataType != DataType::Unknown) - { - if (firstKnownInputDataType == DataType::Unknown) - firstKnownInputDataType = inputDataType; - - if (outputDataType == DataType::Unknown) - { - if (!input.IsConstant()) - outputDataType = inputDataType; - } - else - { - // The DataType of all operands should match except for Constants where we allow coercion - if ((inputDataType != DataType::Unknown) && (inputDataType != outputDataType) && !input.IsConstant()) - InvalidArgument("Primitive function with op type %S has operands with different DataTypes %s and %s", PrimitiveOpTypeName(op).c_str(), DataTypeName(outputDataType), DataTypeName(inputDataType)); - } - } - } - - if (outputDataType == DataType::Unknown) - outputDataType = firstKnownInputDataType; - - // We currently require that the inputs' dynamic axes, if any, match - std::vector outputDynamicAxes; - if ((op == PrimitiveOpType::SumAll) || - (op == PrimitiveOpType::SquaredError) || - (op == PrimitiveOpType::CrossEntropyWithSoftmax) || - (op == PrimitiveOpType::ClassificationError) || - (op == PrimitiveOpType::Logistic)) - { - outputDynamicAxes = std::vector({}); - } - else if (op == PrimitiveOpType::Where) - { - if (functionConfig.Contains(PrimitiveFunction::AttributeNameNewDynamicAxes)) - outputDynamicAxes = AsVector(functionConfig[PrimitiveFunction::AttributeNameNewDynamicAxes].Value>()); - else - { - if (inputs[0].DynamicAxes() == Axis::UnknownDynamicAxes()) - outputDynamicAxes = Axis::UnknownDynamicAxes(); - else - { - if (functionConfig.Contains(PrimitiveFunction::AttributeNameNewSequenceAxisLengthScalingFactor) && - functionConfig.Contains(PrimitiveFunction::AttributeNameNewSequenceAxisLengthAdditiveFactor)) - { - size_t newSequenceAxisLengthScalingFactor = functionConfig[PrimitiveFunction::AttributeNameNewSequenceAxisLengthScalingFactor].Value(); - int newSequenceAxisLengthAdditiveFactor = functionConfig[PrimitiveFunction::AttributeNameNewSequenceAxisLengthAdditiveFactor].Value(); 
- - auto derivedDynamicAxes = GetDerivedDynamicAxes(inputs[0].DynamicAxes()[0], newSequenceAxisLengthScalingFactor, newSequenceAxisLengthAdditiveFactor); - std::copy(derivedDynamicAxes.begin(), derivedDynamicAxes.end(), std::back_inserter(outputDynamicAxes)); - } - else - { - outputDynamicAxes.push_back(Axis::NewUniqueDynamicAxis(L"whereNodeDynamicAxis")); - } - - for (size_t i = 1; i < inputs[0].DynamicAxes().size(); ++i) - outputDynamicAxes.push_back(inputs[0].DynamicAxes()[i]); - - functionConfig[PrimitiveFunction::AttributeNameNewDynamicAxes] = AsDictionaryValueVector(outputDynamicAxes); - } - } - } - else if (op == PrimitiveOpType::ScatterPacked) - outputDynamicAxes = inputs[2].DynamicAxes(); - else if ((op == PrimitiveOpType::PackedIndex) || (op == PrimitiveOpType::GatherPacked)) - outputDynamicAxes = inputs[1].DynamicAxes(); - else if (op == PrimitiveOpType::ReconcileDynamicAxis) - outputDynamicAxes = inputs[1].DynamicAxes(); - else - { - auto allInputDynamicAxesEmpty = std::find_if(inputs.begin(), inputs.end(), [](const Variable& input) { return !input.DynamicAxes().empty(); }) == inputs.end(); - if (!allInputDynamicAxesEmpty) - { - outputDynamicAxes = Axis::UnknownDynamicAxes(); - for (auto inputVar : inputs) - { - auto currentInputDynamicAxes = inputVar.DynamicAxes(); - if (!currentInputDynamicAxes.empty() && (currentInputDynamicAxes != Axis::UnknownDynamicAxes())) - { - if (outputDynamicAxes == Axis::UnknownDynamicAxes()) - outputDynamicAxes = currentInputDynamicAxes; - else - { - if (currentInputDynamicAxes != outputDynamicAxes) - LogicError("Currently if an operand of a elementwise operation has any dynamic axes, those must match the dynamic axes of the other operands"); - } - } - } - } - } - - NDShape outputShape; - bool areAnyInputShapesUnknown = (std::find_if(inputs.begin(), inputs.end(), [](const Variable& input) { return input.Shape().IsUnknown(); }) != inputs.end()); - if (areAnyInputShapesUnknown) - outputShape = NDShape::Unknown; - else - { - switch (op) - { - case PrimitiveOpType::Negate: - case PrimitiveOpType::Sigmoid: - case PrimitiveOpType::Tanh: - case PrimitiveOpType::ReLU: - case PrimitiveOpType::Exp: - case PrimitiveOpType::Log: - case PrimitiveOpType::Sqrt: - case PrimitiveOpType::Floor: - case PrimitiveOpType::Abs: - case PrimitiveOpType::Reciprocal: - case PrimitiveOpType::Softmax: - case PrimitiveOpType::Hardmax: - case PrimitiveOpType::Dropout: - case PrimitiveOpType::Where: - case PrimitiveOpType::LogSoftmax: - { - assert(inputs.size() == 1); - outputShape = UnaryElementwiseOpOutputShape(inputs[0].Shape()); - break; - } - case PrimitiveOpType::TransposeAxes: - { - assert(inputs.size() == 1); - - auto axis1 = NormalizeStaticAxis(functionConfig[PrimitiveFunction::AttributeNameAxis1].Value(), inputs[0].Shape()); - auto axis2 = NormalizeStaticAxis(functionConfig[PrimitiveFunction::AttributeNameAxis2].Value(), inputs[0].Shape()); - - if (!axis1.IsStaticAxis() || !axis2.IsStaticAxis()) - LogicError("TransposeAxes operation currently does not support transposing dynamic axes"); - - VerifyStaticAxis(axis1, inputs[0].Shape()); - VerifyStaticAxis(axis2, inputs[0].Shape()); - - outputShape = inputs[0].Shape(); - std::swap(outputShape[axis1.StaticAxisIndex()], outputShape[axis2.StaticAxisIndex()]); - break; - } - case PrimitiveOpType::Slice: - { - assert(inputs.size() == 1); - auto axis = NormalizeStaticAxis(functionConfig[PrimitiveFunction::AttributeNameAxis].Value(), inputs[0].Shape()); - - auto beginIndex = 
functionConfig[PrimitiveFunction::AttributeNameBeginIndex].Value(); - auto endIndex = functionConfig[PrimitiveFunction::AttributeNameEndIndex].Value(); - if (!axis.IsStaticAxis()) - LogicError("Built-in Slice operation currently does not support slicing along dynamic axis"); - - VerifyStaticAxis(axis, inputs[0].Shape()); - - size_t sliceAxisDim = inputs[0].Shape()[axis.StaticAxisIndex()]; - int realBeginIndex = (beginIndex >= 0) ? beginIndex : beginIndex + sliceAxisDim; - int realEndIndex = (endIndex > 0) ? endIndex : endIndex + sliceAxisDim; - if ((sliceAxisDim < realEndIndex) || (realEndIndex < realBeginIndex) || (realBeginIndex < 0)) - RuntimeError("Slice operation: Index range [%d,%d), interpreted as [%d,%d), is invalid for input's shape ([%S]).", - beginIndex, - endIndex, - realBeginIndex, - realEndIndex, - AsStringForErrorReporting(inputs[0].Shape()).c_str()); - - auto outputTensorShape = AsTensorShape(inputs[0].Shape()); - - // propagate as much as we can - if ((axis.StaticAxisIndex() < (int)outputTensorShape.GetRank()) && (0 <= realBeginIndex) && (realBeginIndex <= realEndIndex) && (realEndIndex <= sliceAxisDim)) - outputTensorShape.NarrowTo(axis.StaticAxisIndex(), realBeginIndex, realEndIndex); - - outputShape = AsNDShape(outputTensorShape, /*allowNonFlattenableTensorShapes = */ true); - break; - } - case PrimitiveOpType::Reshape: - { - auto newShape = functionConfig[PrimitiveFunction::AttributeNameNewShape].Value(); - outputShape = ReshapeOutputShape(inputs[0].Shape(), newShape); - break; - } - case PrimitiveOpType::ROIPooling: - { - assert(inputs.size() == 2); - auto convMapShape = inputs[0].Shape(); - auto roisShape = inputs[1].Shape(); - auto roiOutputShape = functionConfig[PrimitiveFunction::AttributeNameROIOutputShape].Value(); - - auto outW = roiOutputShape[0]; - auto outH = roiOutputShape[1]; - auto numChannels = convMapShape[2]; - auto roisPerImage = roisShape[1]; - - if (roiOutputShape.Rank() != 2) - InvalidArgument("ROIPoolingNode: roi output shape must have two dimensions ([W x H])."); - - if (convMapShape[0] < outW || convMapShape[1] < outH) - InvalidArgument("ROIPoolingNode: inputWidth must >= windowWidth and inputHeight must >= windowHeight."); - - if (convMapShape[2] < 1) - InvalidArgument("ROIPoolingNode: input must have at least one channel ([W x H x C])."); - - if (roisShape[0] != 4) - InvalidArgument("ROIPoolingNode: ROI input must have the following shape: [4 x roisPerImage]."); - - if (roisPerImage < 1) - InvalidArgument("ROIPoolingNode: ROI input must contain at least one ROI ([4 x roisPerImage])."); - - outputShape = { outW, outH, numChannels, roisPerImage }; - break; - } - case PrimitiveOpType::Pooling: - { - assert(inputs.size() == 1); - auto poolingWindowsShape = functionConfig[PrimitiveFunction::AttributeNamePoolingWindowShape].Value(); - auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value(); - auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value(); - auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value(); - auto autoPadding = AsVector(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value>()); - NDShape outputMapCount = { 1 }; - std::vector sharing = { true }; - auto inputShape = inputs[0].Shape(); - - // In case of pooling if the kernel shape is unknown, then treat it as global pooling. 
- if (poolingWindowsShape == NDShape::Unknown) - { - if ((std::find(autoPadding.begin(), autoPadding.end(), true) != autoPadding.end()) || - (lowerPad.TotalSize() > 0) || (upperPad.TotalSize() > 0)) - RuntimeError("Padding isn't allowed for Unknown shape!"); - - poolingWindowsShape = inputShape.SubShape(0, inputShape.Rank()-1); - functionConfig[PrimitiveFunction::AttributeNamePoolingWindowShape] = poolingWindowsShape; - } - - outputShape = ConvolutionOpOutputShape(op, inputShape, poolingWindowsShape, outputMapCount, strides, sharing, autoPadding, lowerPad, upperPad, false, inferDimensions); - break; - } - case PrimitiveOpType::SumAll: - assert(inputs.size() == 1); - outputShape = {1}; - break; - case PrimitiveOpType::Plus: - case PrimitiveOpType::LogPlus: - case PrimitiveOpType::Minus: - case PrimitiveOpType::ElementTimes: - case PrimitiveOpType::Equal: - case PrimitiveOpType::NotEqual: - case PrimitiveOpType::Less: - case PrimitiveOpType::LessEqual: - case PrimitiveOpType::Greater: - case PrimitiveOpType::GreaterEqual: - case PrimitiveOpType::PastValue: - case PrimitiveOpType::FutureValue: - { - assert(inputs.size() == 2); - if ((op == PrimitiveOpType::PastValue) || (op == PrimitiveOpType::FutureValue)) - { - Variable inputOperandVar = inputs[0]; - Variable initialStateVar = inputs[1]; - - // TODO: We currently only support input operand with 1 dynamic axis for PastValue/FutureValue - if ((inputOperandVar.DynamicAxes() != Axis::UnknownDynamicAxes()) && (inputOperandVar.DynamicAxes().size() != 2)) - LogicError("Currently PastValue/FutureValue Function only supports input operand with 2 dynamic axis (1 sequence-axis and 1 batch-axis)"); - - if (!initialStateVar.DynamicAxes().empty()) - LogicError("Currently PastValue/FutureValue Function does not support initial state operand with dynamic axes!"); - } - - outputShape = BinaryElementwiseOpOutputShape(op, inputs[0], inputs[1], true, inferDimensions); - break; - } - case PrimitiveOpType::Times: - { - assert(inputs.size() == 2); - auto outputRank = functionConfig[PrimitiveFunction::AttributeNameOutputRank].Value(); - auto inferInputRankToMap = functionConfig[PrimitiveFunction::AttributeNameInferInputRankToMap].Value(); - outputShape = TimesOpOutputShape(inputs[0], inputs[1], outputRank, inferInputRankToMap, inferDimensions); - break; - } - case PrimitiveOpType::TransposeTimes: - { - assert(inputs.size() == 2); - - auto transposeShapeFunc = [](const NDShape& shape) { - NDShape transposedShape(std::max(2, shape.Rank()), 1); - for (size_t i = 0; i < shape.Rank(); ++i) - transposedShape[transposedShape.Rank() - i - 1] = shape[i]; - - return transposedShape; - }; - - if (inputs[0].Shape().Rank() > 2) - LogicError("TransposeTimes operation currently only supports %s operands of rank 1 or 2", Internal::IsReversingTensorShapesInErrorMessagesEnabled() ? 
"right" : "left"); - - NDShape transposedLeftOperandShape = transposeShapeFunc(inputs[0].Shape()); - Variable dummyLeftOperand = PlaceholderVariable(transposedLeftOperandShape); - size_t outputRank = functionConfig[PrimitiveFunction::AttributeNameOutputRank].Value(); - outputShape = TimesOpOutputShape(dummyLeftOperand, inputs[1], outputRank, -1, inferDimensions); - if (dummyLeftOperand.Shape() != transposedLeftOperandShape) - inputs[0].m_dataFields->m_shape = transposeShapeFunc(dummyLeftOperand.Shape()); - - break; - } - case PrimitiveOpType::Convolution: - { - assert(inputs.size() == 2); - auto& strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value(); - auto& lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value(); - auto& upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value(); - auto sharing = AsVector(functionConfig[PrimitiveFunction::AttributeNameSharing].Value>()); - auto autoPadding = AsVector(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value>()); - bool transpose = functionConfig[PrimitiveFunction::AttributeNameTranspose].Value(); - if (inputs[0].Shape().Rank() < inputs[1].Shape().Rank()) - InvalidArgument("The convolution map should have at least as many axes as the shape of the input it operates on!"); - - NDShape outputMapCount, kernelShape; - std::tie(outputMapCount, kernelShape) = GetConvolutionOutputMapCountAndKernelShape(inputs[0].Shape(), inputs[1].Shape()); - auto originalKernelShape = kernelShape; - outputShape = ConvolutionOpOutputShape(op, inputs[1].Shape(), kernelShape, outputMapCount, strides, sharing, autoPadding, lowerPad, upperPad, transpose, inferDimensions); - if (originalKernelShape != kernelShape) - { - for (size_t i = 0; i < kernelShape.Rank(); ++i) - inputs[0].m_dataFields->m_shape[i] = kernelShape[i]; - } - - functionConfig[PrimitiveFunction::AttributeNameSharing] = AsDictionaryValueVector(sharing); - functionConfig[PrimitiveFunction::AttributeNameAutoPadding] = AsDictionaryValueVector(autoPadding); - break; - } - case PrimitiveOpType::Logistic: - case PrimitiveOpType::SquaredError: - case PrimitiveOpType::CrossEntropyWithSoftmax: - case PrimitiveOpType::ClassificationError: - { - if ((op == PrimitiveOpType::ClassificationError) || (op == PrimitiveOpType::Logistic)) - assert(inputs.size() >= 2); - else - assert(inputs.size() == 2); - - if ((inputs[0].Shape().Rank() > 2) || ((inputs[0].Shape().Rank() > 1) && (inputs[0].Shape()[1] != 1))) - InvalidArgument("The shape of input operands for the %S operation should have at most one axis", PrimitiveOpTypeName(op).c_str()); - - auto predictionShape = inputs[0].Shape(); - auto labelsShape = inputs[1].Shape(); - if (predictionShape != labelsShape) - RuntimeError("Prediction output operand's shape %S is incompatible with label operand's shape %S for the %S operation", AsStringForErrorReporting(predictionShape).c_str(), AsStringForErrorReporting(labelsShape).c_str(), PrimitiveOpTypeName(op).c_str()); - - std::vector reductionAxes; - for (int i = 0; i < (int)inputs[0].Shape().Rank(); ++i) - reductionAxes.push_back(i); - - outputShape = ReductionOpOutputShape(op, predictionShape, reductionAxes, /*preserveReductionAxes =*/ false); - break; - } - case PrimitiveOpType::ReduceElements: - { - assert(inputs.size() == 1); - auto reductionAxis = NormalizeStaticAxis(functionConfig[PrimitiveFunction::AttributeNameAxis].Value(), inputs[0].Shape()); - if (reductionAxis == Axis::AllStaticAxes()) - outputShape = {}; - else - { - std::vector reductionAxes = { 
reductionAxis.StaticAxisIndex() }; - outputShape = ReductionOpOutputShape(op, inputs[0].Shape(), reductionAxes, /*preserveReductionAxes =*/ true); - } - break; - } - case PrimitiveOpType::BatchNormalization: - { - assert(inputs.size() == 5); - auto spatial = functionConfig[PrimitiveFunction::AttributeNameSpatial].Value(); - outputShape = BatchNormalizationOutputShape(inputs, spatial, inferDimensions); - break; - } - case PrimitiveOpType::PackedIndex: - outputShape = UnaryElementwiseOpOutputShape(inputs[1].Shape()); - break; - case PrimitiveOpType::GatherPacked: - { - bool sourceHasDynamicAxis = !inputs[0].DynamicAxes().empty(); - - // inherit tensor dimension from sourceData, minus the last (column or time) dimension. TODO this needs to become simpler... - if (sourceHasDynamicAxis) - outputShape = inputs[0].Shape(); - else - { - if (inputs[0].Shape().Rank() > 1) - outputShape = outputShape.SubShape(0, outputShape.Rank() - 1); - else - outputShape = {}; - } - - break; - } - case PrimitiveOpType::ScatterPacked: - { - if (inputs[0].DynamicAxes().empty() || inputs[1].DynamicAxes().empty() || inputs[2].DynamicAxes().empty()) - InvalidArgument("ScatterPacked requires all its operands to have dynamic axes"); - - outputShape = inputs[0].Shape(); - break; - } - case PrimitiveOpType::Clip: - assert(inputs.size() == 3); - outputShape = UnaryElementwiseOpOutputShape(inputs[0].Shape()); - break; - case PrimitiveOpType::Select: - assert(inputs.size() == 3); - outputShape = NaryElementwiseOpOutputShape(op, inputs, true); - break; - case PrimitiveOpType::Splice: - { - assert(inputs.size() >= 2); - auto maxInputRank = MaxInputRank(inputs); - auto spliceAxis = NormalizeStaticAxis(functionConfig[PrimitiveFunction::AttributeNameAxis].Value(), NDShape(maxInputRank)); - - if (!spliceAxis.IsStaticAxis()) - LogicError("Splice operation currently does not support splicing along dynamic axis"); - - if (spliceAxis.StaticAxisIndex() < 0) - InvalidArgument("Splice: The axis argument's static axis index must be >= 0!"); - - outputShape = SpliceOutputShape(inputs, spliceAxis.StaticAxisIndex()); - break; - } - case PrimitiveOpType::RandomSample: - case PrimitiveOpType::RandomSampleInclusionFrequency: - { - auto numSamples = functionConfig[PrimitiveFunction::AttributeNameNumSamples].Value(); - auto allowDuplicates = functionConfig[PrimitiveFunction::AttributeNameAllowDuplicates].Value(); - - if (numSamples == 0) - InvalidArgument("Number of requested samples is zero."); - - let& shape = inputs[0].Shape(); - size_t numClasses = shape.Dimensions()[0]; - - if (numClasses != NDShape::InferredDimension && !allowDuplicates && numClasses <= numSamples) - InvalidArgument("For sampling without duplicates the number of requested samples (%lu) needs to be less than the number of classes (%lu).", numSamples, numClasses); - - // within this block we handle RandomSample and RandomSampleInclusionFrequency - if (op == PrimitiveOpType::RandomSampleInclusionFrequency) - outputShape = shape; - else - { - vector dimensions{ numSamples, numClasses }; - outputShape = NDShape(dimensions); - } - - break; - } - case PrimitiveOpType::OptimizedRNNStack: - { - assert(inputs.size() == 2); - auto operand = inputs[0]; - auto parameter = inputs[1]; - if (operand.Shape().Rank() != 1) - InvalidArgument("OptimizedRNNStack: input must have rank 1; actual input rank is %lu", operand.Shape().Rank()); - if (operand.DynamicAxes().empty()) - InvalidArgument("OptimizedRNNStack: input must have at least one dynamic axis"); - auto numLayers = 
functionConfig[PrimitiveFunction::AttributeNameNumLayers].Value(); - if (numLayers == 0) - InvalidArgument("Number of layers in OptimizedRNNStack operation should be positive"); - auto bidirectional = functionConfig[PrimitiveFunction::AttributeNameBidirectional].Value(); - auto hiddenSize = functionConfig[PrimitiveFunction::AttributeNameHiddenSize].Value(); - - // output dims - outputShape = operand.Shape(); - outputShape[0] = (bidirectional ? 2 : 1) * hiddenSize; - // infer input size - // Note: Output dim is second axis, so say initOutputRank=-1. - if (parameter.Shape().Rank() == 2) - { - const auto recurrentOp = functionConfig[PrimitiveFunction::AttributeNameRecurrentOp].Value(); - const auto attributes = RnnAttributes(bidirectional, numLayers, hiddenSize, recurrentOp, -1); - const auto numParameters = attributes.GetNumParameters(operand.Shape().TotalSize()); - std::vector> newOperandShapes = { { parameter, std::move(NDShape({ numParameters.first, numParameters.second })) } }; - UpdateOperandShapes(newOperandShapes); - } - break; - } - case PrimitiveOpType::ReconcileDynamicAxis: - { - assert(inputs.size() == 2); - auto operand = inputs[0]; - auto layout = inputs[1]; - if (operand.DynamicAxes().empty()) - InvalidArgument("ReconcileDynamicAxis: input must have at least one dynamic axis"); - if (layout.DynamicAxes().empty()) - InvalidArgument("ReconcileDynamicAxis: layout must have at least one dynamic axis"); - outputShape = operand.Shape(); - break; - } - default: - LogicError("Specified op %S not yet supported", PrimitiveOpTypeName(op).c_str()); - break; - } - } - - return{ OutputVariable(outputShape, outputDataType, owner, outputDynamicAxes, functionName.empty() ? L"" : functionName + L"_output") }; - } - - /*static*/ const std::wstring CompositeFunction::CompositeFunctionOpName = L"CompositeFunctionOpName"; - /*static*/ std::atomic CompositeFunction::s_nextAutoGeneratedDynamicAxis(0); - - static const std::wstring s_primitiveFunctionTypeValue = L"PrimitiveFunction"; - - /*virtual*/ Dictionary PrimitiveFunction::Serialize() const - { - Dictionary dict; - - dict[versionKey] = CurrentVersion(); - dict[typeKey] = s_primitiveFunctionTypeValue; - dict[opKey] = static_cast(m_op); - dict[attributesKey] = Attributes(); - dict[uidKey] = Uid(); - dict[nameKey] = Name(); - - auto inputs = Inputs(); - vector inputUids; - inputUids.reserve(inputs.size()); - for (auto& input : inputs) - { - inputUids.push_back(input.Uid()); - } - - dict[inputsKey] = std::move(inputUids); - - return dict; - } - - /*static*/ FunctionPtr PrimitiveFunction::Deserialize(const Dictionary& dict, - const std::unordered_map& uidToVariableMap, - const CNTK::DeviceDescriptor& device) - { - static const vector s_requiredDictionaryKeys = { typeKey, opKey, uidKey, attributesKey, inputsKey, nameKey }; - - size_t version = ValidateDictionary(dict, s_requiredDictionaryKeys, s_primitiveFunctionTypeValue, s_serializationVersion); - - PrimitiveOpType op = PrimitiveOpType(dict[opKey].Value()); - - // The hard requirement that the serialization depends on is that - // new op type values are only added to the end of the list, after Combine. - // This also applies to other enums (DataType, VariableKind, etc.) 
- if (op > PrimitiveOpType::Combine) - { - LogicError("Unexpected variable '%ls':'%u' (%s).", - opKey.c_str(), - static_cast::type>(op), - GetVersionsString(s_serializationVersion, version).c_str()); - } - - const auto& uid = dict[uidKey].Value(); - const auto& name = dict[nameKey].Value(); - auto attributes = dict[attributesKey].Value(); - const auto& inputUids = dict[inputsKey].Value>(); - - std::vector inputs; - inputs.reserve(inputUids.size()); - - for (const auto& dictionaryValue : inputUids) - { - const auto& inputUid = dictionaryValue.Value(); - if (uidToVariableMap.find(inputUid) == uidToVariableMap.end()) - { - LogicError("There are no inputs corresponging to input uid = '%ls' " - "(%s).", inputUid.c_str(), GetVersionsString(s_serializationVersion, version).c_str()); - } - inputs.push_back(uidToVariableMap.at(inputUid)); - } - - return std::shared_ptr(new PrimitiveFunction(op, inputs, std::move(attributes), name, uid), - [](PrimitiveFunction* ptr) { delete ptr; }); - } - - static const std::wstring s_compositeFunctionTypeValue = L"CompositeFunction"; - - /*virtual*/ Dictionary CompositeFunction::Serialize() const - { - Dictionary dict; - - dict[versionKey] = CurrentVersion(); - dict[typeKey] = s_compositeFunctionTypeValue; - dict[rootKey] = RootFunction()->Uid(); - dict[nameKey] = Name(); - dict[uidKey] = Uid(); - - - // Find cycles in the graph and "break" them by inserting placeholders. - // This needs to be done on Save, since here we have easy access to the shape and - // dynamic axis info. - std::unordered_set visitedFunctions; - std::vector topoSortedPrimitiveFunctions; - std::vector inputs; - std::unordered_set inputUids; - Traverse(RootFunction(), [&visitedFunctions, &inputs, &topoSortedPrimitiveFunctions, &inputUids](const FunctionPtr& function) { - std::vector functionInputs = function->Inputs(); - for (const auto& input : functionInputs) - { - auto& uid = input.Uid(); - if (inputUids.find(uid) != inputUids.end()) - { - continue; - } - - // check if this input corresponds to a cyclic edge in the graph. - bool mustBeReplaced = input.IsOutput() && visitedFunctions.find(input.Owner()) != visitedFunctions.end(); - - if (mustBeReplaced) - { - auto varKind = VariableKind::Placeholder; - Variable var(input.Shape(), varKind, input.GetDataType(), nullptr, - input.IsSparse(), input.DynamicAxes(), input.Name(), uid); - inputs.push_back(var); - inputUids.insert(uid); - } - else if (!input.IsOutput()) - { - // leave the input as is. 
- inputs.push_back(input); - inputUids.insert(uid); - } - } - visitedFunctions.insert(function); - topoSortedPrimitiveFunctions.push_back(function); - }); - - std::reverse(std::begin(topoSortedPrimitiveFunctions), std::end(topoSortedPrimitiveFunctions)); - - assert(topoSortedPrimitiveFunctions.size() == m_allPrimitiveFunctions.size()); - assert(topoSortedPrimitiveFunctions.back()->Uid() == RootFunction()->Uid()); - - std::vector inputDictionaries; - inputDictionaries.reserve(inputs.size()); - inputUids.clear(); - for (const auto& input : inputs) - { - if (inputUids.find(input.Uid()) != inputUids.end()) - { - LogicError("Input uids must be unique"); - } - inputUids.insert(input.Uid()); - inputDictionaries.push_back(input.Serialize()); - } - - dict[inputsKey] = std::move(inputDictionaries); - - std::vector functionDictionaries; - std::unordered_set outputUids; - for (const auto& primitiveFunciton : topoSortedPrimitiveFunctions) - { - for (const auto& output : primitiveFunciton->Outputs()) - { - if (outputUids.find(output.Uid()) != outputUids.end()) - { - LogicError("Output uids of all primitive functions in a function graph must be unique"); - } - outputUids.insert(primitiveFunciton->Uid()); - } - functionDictionaries.push_back(primitiveFunciton->Serialize()); - } - - dict[functionsKey] = std::move(functionDictionaries); - - return dict; - } - - /*static*/ FunctionPtr CompositeFunction::Deserialize(const Dictionary& dict, const CNTK::DeviceDescriptor& device) - { - static const vector s_requiredDictionaryKeys = { typeKey, rootKey, nameKey, uidKey, inputsKey, functionsKey }; - - size_t version = ValidateDictionary(dict, s_requiredDictionaryKeys, s_compositeFunctionTypeValue, s_serializationVersion); - - const auto& rootUid = dict[rootKey].Value(); - const auto& name = dict[nameKey].Value(); - const auto& uid = dict[uidKey].Value(); - const auto& inputs = dict[inputsKey].Value>(); - - std::unordered_map uidToInputMap(inputs.size()); - - for (const auto& dictionaryValue : inputs) - { - const auto& dictionary = dictionaryValue.Value(); - const auto& inputVar = Variable::Deserialize(dictionary, device); - - if (uidToInputMap.find(inputVar.Uid()) != uidToInputMap.end()) - { - LogicError("Input uids are not unique (several inputs share '%ls' uid) " - "(%s).", inputVar.Uid().c_str(), GetVersionsString(s_serializationVersion, version).c_str()); - } - uidToInputMap[inputVar.Uid()] = inputVar; - } - - const auto& functions = dict[functionsKey].Value>(); - - FunctionPtr root; - std::unordered_map placeholderReplacements; - std::unordered_set allPrimitiveFunctions; // this keeps all primitive functions alive until a composite function is created. - for (const auto& dictionaryValue : functions) - { - root = PrimitiveFunction::Deserialize(dictionaryValue.Value(), uidToInputMap, device); - allPrimitiveFunctions.insert(root); - - auto primitiveFunction = dynamic_cast(root.get()); - // Since Combine simply forwards other functions' outputs, all of its outputs - // should already be in the uidToInputMap. 
- if (primitiveFunction->OpType() == PrimitiveOpType::Combine) - { - continue; - } - - for (const auto& output : root->Outputs()) - { - const auto& it = uidToInputMap.find(output.Uid()); - if (it != uidToInputMap.end()) - { - if (!it->second.IsPlaceholder()) - { - LogicError("Unexpected variable type %ls instead of a Placeholder for input %ls variable (uid = %ls)" - "(%s).", VariableKindName(it->second.Kind()), it->second.Name().c_str(), it->second.Uid().c_str(), - GetVersionsString(s_serializationVersion, version).c_str()); - } - placeholderReplacements[it->second] = output; - } - else - { - uidToInputMap[output.Uid()] = output; - } - } - } - - if (root->Uid() != rootUid) - { - LogicError("Root UID '%ls' is different from the expected value '%ls'.", root->Uid().c_str(), rootUid.c_str()); - } - - if (placeholderReplacements.size() > 0) - { - return CompositeFunction::Create(root->ReplacePlaceholders(placeholderReplacements), name, uid); - } - - return CompositeFunction::Create(root, name, uid); - } - - // Names of the dynamic axes in the CNTK engine for some special sets of dynamic axes values - // Note: The no sequence axis corresponds to a special case where there is no sequence axis (i.e. has been reduced over) - // and the special name is used to identify this when loading back a model saved in CNTK v1 format. This will not really be needed - // when the new CNTK v2 model serialization format is ready. - /*static*/ const std::wstring CompositeFunction::InternalDefaultDynamicAxisName = L"*"; - /*static*/ const std::wstring CompositeFunction::InternalNoSequenceAxisName = L"__noSequenceAxis"; - - // Replace any PlaceHolder Variables in the graph of Functions underlying 'this' CompositeFunction. All PlaceHolder variables - // should have been replaced before performing any Forward compute of 'this' Function. 
- /*virtual*/ void CompositeFunction::ReplacePlaceholdersInPlace(const std::unordered_map& placeholderReplacements, - std::unordered_set& visitedFunctions, - std::unordered_set& replacedPlaceholders) - { - RootFunction()->ReplacePlaceholdersInPlace(placeholderReplacements, visitedFunctions, replacedPlaceholders); - - // If any of the placeholders were replaced with Output variables, let's add the graph of function underneath each of those to 'm_allPrimitiveFunctions' set - for (auto replacedPlaceholder : replacedPlaceholders) - { - auto replacingVariable = placeholderReplacements.at(replacedPlaceholder); - if (replacingVariable.IsOutput()) - { - auto ownerFunc = replacingVariable.Owner(); - std::unordered_set visitedFunctions; - Collect(ownerFunc, visitedFunctions); - - // Add the newly visited functions to 'm_allPrimitiveFunctions' set - m_allPrimitiveFunctions.insert(visitedFunctions.begin(), visitedFunctions.end()); - } - } - std::unordered_map functionVisitCounts; - - // An arbitrary cap on changing output shape of recurrent nodes, to detect infinite inference loops - const size_t maxNumValidationPassesAllowed = 25; - bool recurrentNodeOutputModified = false; - size_t numValidationPasses = 0; - do - { - recurrentNodeOutputModified = false; - functionVisitCounts.clear(); - RootFunction()->ValidateOrUpdateOutputs(functionVisitCounts, recurrentNodeOutputModified); - numValidationPasses++; - } while (recurrentNodeOutputModified && (numValidationPasses < maxNumValidationPassesAllowed)); - - if (numValidationPasses >= maxNumValidationPassesAllowed) - LogicError("A recurrent node output shape change happened in successive %d validation passes indicating a potential infinite inference loop!", (int)numValidationPasses); - } - - // Recursively create a sub-network of ComputationNode instances corresponding to the graph of Functions - // underlying the specified 'variable' and return the ComputationNode instance that corresponds to the - // top level 'variable' - template - /*static*/ ComputationNodeBasePtr CompositeFunction::GetNode(const Variable& variable, - Microsoft::MSR::CNTK::ComputationNetworkPtr& network, - ComputationNetworkBuilder& builder, - std::unordered_map& variableToNodeMap, - std::unordered_map& isVariableRootMap) - { - auto iter = variableToNodeMap.find(variable); - if (iter != variableToNodeMap.end()) - { - isVariableRootMap[variable] = false; - return iter->second; - } - - // The DataType, Shape and DynamicAxes of the variable must be known by now - if (variable.GetDataType() == DataType::Unknown) - InvalidArgument("Variable%S with unknown DataType detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str()); - - if (variable.Shape().IsUnknown()) - InvalidArgument("Variable%S with unknown shape detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str()); - - if (variable.Shape().HasInferredDimension()) - InvalidArgument("Variable%S with InferredDimension for at least one axis in its shape, detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str()); - - if (variable.DynamicAxes() == Axis::UnknownDynamicAxes()) - InvalidArgument("Variable%S with unknown dynamic axes detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str()); - - // Lets add a null entry in the map for this variable, to break infinite recursion when processing recurrent graphs - variableToNodeMap[variable] = nullptr; - - std::shared_ptr> computationNodePtr; - if (variable.IsParameter() || 
variable.IsConstant()) - { - auto internalNodeName = CNTKInternalNodeNameFromUidAndName(variable.Uid(), variable.Name()); - computationNodePtr = builder.CreateLearnableParameter(internalNodeName, AsTensorShape(variable.Shape())); - network->InitLearnableParameters(computationNodePtr, L"fixedValue", 0); // must call this to follow protocol; can overwrite later - if (!variable.NeedsGradient()) - computationNodePtr->SetLearningRateMultiplier(0.0); - - NDArrayViewPtr value = variable.IsConstant() ? Constant(variable).Value() : Parameter(variable).Value(); - std::shared_ptr> valueMatrix = variable.IsConstant() ? value->GetMatrix() : value->GetWritableMatrix(); - - if (variable.IsParameter() || (valueMatrix->GetDeviceId() == network->GetDeviceId())) - computationNodePtr->Value() = valueMatrix->AsReference(); - else - { - Matrix clonedMatrix(valueMatrix->GetNumRows(), valueMatrix->GetNumCols(), network->GetDeviceId(), valueMatrix->GetMatrixType(), valueMatrix->GetFormat()); - clonedMatrix.AssignValuesOf(*valueMatrix); - computationNodePtr->Value() = std::move(clonedMatrix); - } - } - else if (variable.IsInput()) - { - auto internalNodeName = CNTKInternalNodeNameFromUidAndName(variable.Uid(), variable.Name()); - - // TODO: Input variables currently are required to have the default batch axis - auto dynamicAxes = variable.DynamicAxes(); - auto foundDefaultBatchAxis = std::find(dynamicAxes.begin(), dynamicAxes.end(), Axis::DefaultBatchAxis()); - if (foundDefaultBatchAxis == dynamicAxes.end()) - LogicError("Currently Input Variables are required to have the DefaultBatchAxis as one of their dynamic axes"); - - if (dynamicAxes.back() != Axis::DefaultBatchAxis()) - LogicError("Currently Input Variables are required to have the DefaultBatchAxis as their last dynamic axes"); - - // TODO: Support inputs with > 1 dynamic axes - if ((dynamicAxes.size() < 1) || (dynamicAxes.size() > 2)) - LogicError("Currently only Input variables with 1 or 2 dynamic axis are supported"); - - // Construct the dynamic axis name to be used internally for the CNTK InputNodes - std::wstring internalDynamicAxisName = InternalDynamicAxisNameFromDynamicAxes(dynamicAxes); - - if (!internalDynamicAxisName.empty() && !network->NodeNameExists(internalDynamicAxisName)) - network->AddNodeToNetAndAttachInputs(New>(network->GetDeviceId(), internalDynamicAxisName), {}); - - if (IsSparseInput(variable)) - computationNodePtr = builder.CreateSparseInputNode(internalNodeName, AsTensorShape(variable.Shape()), internalDynamicAxisName); - else - computationNodePtr = builder.CreateInputNode(internalNodeName, AsTensorShape(variable.Shape()), internalDynamicAxisName); - - if (variable.NeedsGradient()) - { - // Set a dummy learning rate multiplier to force gradient computation for the input computation node since by default - // gradients are not computed for Input nodes - computationNodePtr->SetLearningRateMultiplier(0.00001f); - } - } - else - { - assert(variable.IsOutput()); - computationNodePtr = GetOutputVariableNode(variable, network, builder, variableToNodeMap, isVariableRootMap)->template As>()->shared_from_this(); - } - - variableToNodeMap[variable] = computationNodePtr; - if (isVariableRootMap.find(variable) == isVariableRootMap.end()) - isVariableRootMap[variable] = variable.IsOutput(); - - return computationNodePtr; - } - - template - /*static*/ ComputationNodeBasePtr CompositeFunction::CreateComputationNode(const Variable& variable, - PrimitiveFunction* primitiveFunction, - const std::vector>>& inputNodes, - 
Microsoft::MSR::CNTK::ComputationNetworkPtr& network, - std::unordered_map& variableToNodeMap) - { - ComputationNodeBasePtr computationNodePtr; - - auto internalNodeName = CNTKInternalNodeNameFromUidAndName(primitiveFunction->Uid(), primitiveFunction->Name()); - - auto& functionConfig = primitiveFunction->Attributes(); - auto functionInputs = primitiveFunction->Inputs(); - PrimitiveOpType op = primitiveFunction->OpType(); - - switch (op) - { - case PrimitiveOpType::Negate: - computationNodePtr = New>(network->GetDeviceId(), internalNodeName); - break; - case PrimitiveOpType::Sigmoid: - computationNodePtr = New>(network->GetDeviceId(), internalNodeName); - break; - case PrimitiveOpType::Tanh: - computationNodePtr = New>(network->GetDeviceId(), internalNodeName); - break; - case PrimitiveOpType::ReLU: - computationNodePtr = New>(network->GetDeviceId(), internalNodeName); - break; - case PrimitiveOpType::Exp: - computationNodePtr = New>(network->GetDeviceId(), internalNodeName); - break; - case PrimitiveOpType::Log: - computationNodePtr = New>(network->GetDeviceId(), internalNodeName); - break; - case PrimitiveOpType::Sqrt: - computationNodePtr = New>(network->GetDeviceId(), internalNodeName); - break; - case PrimitiveOpType::Floor: - computationNodePtr = New>(network->GetDeviceId(), internalNodeName); - break; - case PrimitiveOpType::Abs: - computationNodePtr = New>(network->GetDeviceId(), internalNodeName); - break; - case PrimitiveOpType::Reciprocal: - computationNodePtr = New>(network->GetDeviceId(), internalNodeName); - break; - case PrimitiveOpType::Softmax: - computationNodePtr = New>(network->GetDeviceId(), internalNodeName); - break; - case PrimitiveOpType::Hardmax: - computationNodePtr = New>(network->GetDeviceId(), internalNodeName); - break; - case PrimitiveOpType::TransposeAxes: - { - auto axis1 = functionConfig[PrimitiveFunction::AttributeNameAxis1].Value(); - auto axis2 = functionConfig[PrimitiveFunction::AttributeNameAxis2].Value(); - - // The axis ids passed to the internal CNTK TransposeDimensionsNode are 1 based instead of 0 based - computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsCNTKInternalAxisIdx(axis1), AsCNTKInternalAxisIdx(axis2)); - break; - } - case PrimitiveOpType::Where: - { - auto dynamicAxes = variable.DynamicAxes(); - auto internalCNTKWhereNodeDynamicAxisName = InternalDynamicAxisNameFromDynamicAxes(dynamicAxes); - computationNodePtr = New>(network->GetDeviceId(), internalNodeName, internalCNTKWhereNodeDynamicAxisName); - break; - } - case PrimitiveOpType::Slice: - { - auto axis = functionConfig[PrimitiveFunction::AttributeNameAxis].Value(); - auto beginIndex = functionConfig[PrimitiveFunction::AttributeNameBeginIndex].Value(); - auto endIndex = functionConfig[PrimitiveFunction::AttributeNameEndIndex].Value(); - - // Internal CNTK SliceNode takes 1 based axis indices instead of 0 based - computationNodePtr = New>(network->GetDeviceId(), internalNodeName, beginIndex, endIndex, AsCNTKInternalAxisIdx(axis)); - break; - } - case PrimitiveOpType::RandomSample: - { - auto numSamples = functionConfig[PrimitiveFunction::AttributeNameNumSamples].Value(); - auto allowDuplicates = functionConfig[PrimitiveFunction::AttributeNameAllowDuplicates].Value(); - computationNodePtr = New>(network->GetDeviceId(), internalNodeName, numSamples, allowDuplicates); - break; - } - case PrimitiveOpType::RandomSampleInclusionFrequency: - { - auto numSamples = functionConfig[PrimitiveFunction::AttributeNameNumSamples].Value(); - auto allowDuplicates = 
functionConfig[PrimitiveFunction::AttributeNameAllowDuplicates].Value(); - computationNodePtr = New>(network->GetDeviceId(), internalNodeName, numSamples, allowDuplicates); - break; - } - case PrimitiveOpType::Dropout: - { - auto dropoutRate = functionConfig[PrimitiveFunction::AttributeNameDropoutRate].Value(); - computationNodePtr = New>(network->GetDeviceId(), internalNodeName); - computationNodePtr->As>()->SetDropoutRate(dropoutRate); - break; - } - case PrimitiveOpType::Reshape: - { - auto newShape = functionConfig[PrimitiveFunction::AttributeNameNewShape].Value(); - computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsTensorShape(newShape)); - break; - } - case PrimitiveOpType::ROIPooling: - { - auto roiOutputShape = functionConfig[PrimitiveFunction::AttributeNameROIOutputShape].Value(); - computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsTensorShape(roiOutputShape)); - break; - } - case PrimitiveOpType::Pooling: - { - PoolingType poolingType = (PoolingType)(functionConfig[PrimitiveFunction::AttributeNamePoolingType].Value()); - auto poolingWindowsShape = functionConfig[PrimitiveFunction::AttributeNamePoolingWindowShape].Value(); - auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value(); - auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value(); - auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value(); - auto autoPadding = AsVector(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value>()); - computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsCNTKPoolKind(poolingType), AsTensorShape(poolingWindowsShape), AsTensorShape(strides), autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), ImageLayoutKind::CHW); - break; - } - case PrimitiveOpType::SumAll: - computationNodePtr = New>(network->GetDeviceId(), internalNodeName); - break; - case PrimitiveOpType::Plus: - computationNodePtr = New>(network->GetDeviceId(), internalNodeName); - break; - case PrimitiveOpType::LogPlus: - computationNodePtr = New>(network->GetDeviceId(), internalNodeName); - break; - case PrimitiveOpType::Minus: - computationNodePtr = New>(network->GetDeviceId(), internalNodeName); - break; - case PrimitiveOpType::ElementTimes: - computationNodePtr = New>(network->GetDeviceId(), internalNodeName); - break; - case PrimitiveOpType::Equal: - computationNodePtr = New>(network->GetDeviceId(), internalNodeName); - break; - case PrimitiveOpType::NotEqual: - computationNodePtr = New>(network->GetDeviceId(), internalNodeName); - break; - case PrimitiveOpType::Less: - computationNodePtr = New>(network->GetDeviceId(), internalNodeName); - break; - case PrimitiveOpType::LessEqual: - computationNodePtr = New>(network->GetDeviceId(), internalNodeName); - break; - case PrimitiveOpType::Greater: - computationNodePtr = New>(network->GetDeviceId(), internalNodeName); - break; - case PrimitiveOpType::GreaterEqual: - computationNodePtr = New>(network->GetDeviceId(), internalNodeName); - break; - case PrimitiveOpType::Times: - { - size_t outputRank = functionConfig[PrimitiveFunction::AttributeNameOutputRank].Value(); - auto inferInputRankToMap = functionConfig[PrimitiveFunction::AttributeNameInferInputRankToMap].Value(); - computationNodePtr = New>(network->GetDeviceId(), internalNodeName, outputRank, inferInputRankToMap); - break; - } - case PrimitiveOpType::TransposeTimes: - { - size_t outputRank = functionConfig[PrimitiveFunction::AttributeNameOutputRank].Value(); - computationNodePtr = 
New>(network->GetDeviceId(), internalNodeName, outputRank); - break; - } - case PrimitiveOpType::Convolution: - { - NDShape outputMapCount, kernelShape; - std::tie(outputMapCount, kernelShape) = GetConvolutionOutputMapCountAndKernelShape(functionInputs[0].Shape(), functionInputs[1].Shape()); - auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value(); - auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value(); - auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value(); - auto sharing = AsVector(functionConfig[PrimitiveFunction::AttributeNameSharing].Value>()); - auto autoPadding = AsVector(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value>()); - auto transpose = functionConfig[PrimitiveFunction::AttributeNameTranspose].Value(); - auto maxTempMemSizeInSamples = functionConfig[PrimitiveFunction::AttributeNameMaxTempMemSizeInSamples].Value(); - computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsTensorShape(kernelShape), AsTensorShape(outputMapCount), AsTensorShape(strides), sharing, autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), transpose, ImageLayoutKind::CHW, maxTempMemSizeInSamples); - break; - } - case PrimitiveOpType::Logistic: - computationNodePtr = New>(network->GetDeviceId(), internalNodeName); - break; - case PrimitiveOpType::SquaredError: - computationNodePtr = New>(network->GetDeviceId(), internalNodeName); - break; - case PrimitiveOpType::CrossEntropyWithSoftmax: - computationNodePtr = New>(network->GetDeviceId(), internalNodeName); - break; - case PrimitiveOpType::ClassificationError: - computationNodePtr = New>(network->GetDeviceId(), internalNodeName); - break; - case PrimitiveOpType::PastValue: - case PrimitiveOpType::FutureValue: - { - Variable inputOperandVar = functionInputs[0]; - Variable initialStateVar = functionInputs[1]; - - size_t offset = primitiveFunction->Attributes()[PrimitiveFunction::AttributeNameOffset].Value(); - if (op == PrimitiveOpType::PastValue) - computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsTensorShape(inputOperandVar.Shape()), offset); - else - computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsTensorShape(inputOperandVar.Shape()), offset); - - break; - } - case PrimitiveOpType::ReduceElements: - { - auto reductionAxis = functionConfig[PrimitiveFunction::AttributeNameAxis].Value(); - auto reductionOpName = functionConfig[PrimitiveFunction::AttributeNameReductionOpName].Value(); - computationNodePtr = New>(network->GetDeviceId(), internalNodeName, reductionOpName, AsCNTKInternalAxisIdx(reductionAxis)); - break; - } - case PrimitiveOpType::BatchNormalization: - { - auto spatial = functionConfig[PrimitiveFunction::AttributeNameSpatial].Value(); - auto normalizationTimeConstant = functionConfig[PrimitiveFunction::AttributeNameNormalizationTimeConstant].Value(); - auto blendTimeConstant = functionConfig[PrimitiveFunction::AttributeNameBlendTimeConstant].Value(); - auto epsilon = functionConfig[PrimitiveFunction::AttributeNameEpsilon].Value(); - auto useCuDNNEngine = functionConfig[PrimitiveFunction::AttributeNameUseCuDNNEngine].Value(); - - computationNodePtr = New>(network->GetDeviceId(), internalNodeName, spatial, normalizationTimeConstant, blendTimeConstant, epsilon, !useCuDNNEngine, ImageLayoutKind::CHW); - break; - } - case PrimitiveOpType::Combine: - // This operation is just a no-op and is a means to combine multiple functions to create a single Function - // whose outputs are a union of 
the outputs of the Functions being combined. - computationNodePtr = variableToNodeMap[variable]; - break; - case PrimitiveOpType::PackedIndex: - computationNodePtr = New>(network->GetDeviceId(), internalNodeName); - break; - case PrimitiveOpType::GatherPacked: - computationNodePtr = New>(network->GetDeviceId(), internalNodeName); - break; - case PrimitiveOpType::ScatterPacked: - computationNodePtr = New>(network->GetDeviceId(), internalNodeName); - break; - case PrimitiveOpType::Clip: - computationNodePtr = New>(network->GetDeviceId(), internalNodeName); - break; - case PrimitiveOpType::Select: - computationNodePtr = New>(network->GetDeviceId(), internalNodeName); - break; - case PrimitiveOpType::Splice: - { - Axis spliceAxis = functionConfig[PrimitiveFunction::AttributeNameAxis].Value(); - computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsCNTKInternalAxisIdx(spliceAxis)); - break; - } - case PrimitiveOpType::OptimizedRNNStack: - { - auto bidirectional = functionConfig[PrimitiveFunction::AttributeNameBidirectional].Value(); - auto numLayers = functionConfig[PrimitiveFunction::AttributeNameNumLayers].Value(); - auto hiddenSize = functionConfig[PrimitiveFunction::AttributeNameHiddenSize].Value(); - auto recurrentOp = functionConfig[PrimitiveFunction::AttributeNameRecurrentOp].Value(); - - computationNodePtr = New>(network->GetDeviceId(), internalNodeName, bidirectional, numLayers, hiddenSize, recurrentOp); - break; - } - case PrimitiveOpType::ReconcileDynamicAxis: - { - computationNodePtr = New>(network->GetDeviceId(), internalNodeName); - break; - } - case PrimitiveOpType::LogSoftmax: - { - //This can be implemented as x => x - ReduceLogSum(x). How to do this here? - computationNodePtr = New>(network->GetDeviceId(), internalNodeName); - break; - } - default: - LogicError("Specified op %S not yet supported", PrimitiveOpTypeName(op).c_str()); - break; - } - - std::vector inputNodesBasePtrs; - for (auto inputNode : inputNodes) - inputNodesBasePtrs.push_back(inputNode); - - // Let's reorder inputNodesBasePtrs properly since the ordering of inputs of CNTK internal ComputationNode may be different from the PrimitiveFunction inputs ordering - ReorderAsCNTKComputationNodeInputs(op, inputNodesBasePtrs); - if (computationNodePtr->Is()) - { - auto computationNodeExpectedInputCount = computationNodePtr->As()->GetExpectedNumInputs(); - if (computationNodeExpectedInputCount != inputNodesBasePtrs.size()) - LogicError("Input count mismatch: The Primitive function for op %S has %d inputs while the corresponding ComputationNode has %d inputs", - PrimitiveOpTypeName(op).c_str(), - (int)inputNodesBasePtrs.size(), - (int)computationNodeExpectedInputCount); - } - - network->AddNodeToNetAndAttachInputs(computationNodePtr, inputNodesBasePtrs); - - return computationNodePtr; - } - - template - /*static*/ ComputationNodeBasePtr CompositeFunction::GetOutputVariableNode(const Variable& variable, - Microsoft::MSR::CNTK::ComputationNetworkPtr& network, - ComputationNetworkBuilder& builder, - std::unordered_map& variableToNodeMap, - std::unordered_map& isVariableRootMap) - { - assert(variable.IsOutput()); - - Function* function = variable.Owner().get(); - ComputationNodeBasePtr computationNodePtr; - if (dynamic_cast(function)) - { - PrimitiveFunction* primitiveFunction = dynamic_cast(function); - PrimitiveOpType op = primitiveFunction->OpType(); - auto& functionInputs = primitiveFunction->m_inputs; - - DataType nonConstInputDataType = DataType::Unknown; - for (auto& inputVar : functionInputs) - { - if 
(!inputVar.IsConstant() && (inputVar.GetDataType() != DataType::Unknown)) - { - nonConstInputDataType = inputVar.GetDataType(); - break; - } - } - - // Create the nodes corresponding to the inputs - std::vector>> inputNodes; - for (auto& inputVar : functionInputs) - { - // If the inputVar is a constant and not the right DataType let's coerce it to the right type - if (inputVar.IsConstant() && (nonConstInputDataType != DataType::Unknown) && (inputVar.GetDataType() != nonConstInputDataType)) - { - auto originalConstantValue = Constant(inputVar).Value(); - auto constantValueCPU = originalConstantValue->DeepClone(DeviceDescriptor::CPUDevice(), true); - NDArrayViewPtr newConstantValue = CloneAsDataType(constantValueCPU, nonConstInputDataType, true); - inputVar = Constant(newConstantValue->DeepClone(originalConstantValue->Device(), originalConstantValue->IsReadOnly()), inputVar.Name()); - } - - auto baseNodePtr = GetNode(inputVar, network, builder, variableToNodeMap, isVariableRootMap); - inputNodes.push_back((baseNodePtr != nullptr) ? baseNodePtr->template As>()->shared_from_this() : nullptr); - } - - computationNodePtr = CreateComputationNode(variable, primitiveFunction, inputNodes, network, variableToNodeMap); - if (op != PrimitiveOpType::Combine) - { - for (auto inputVar : functionInputs) - isVariableRootMap[inputVar] = false; - } - } - else - LogicError("User defined Functions are currently unsupported!"); - - return computationNodePtr; - } - - template - ComputationNetworkPtr CompositeFunction::GetComputationNetwork(const DeviceDescriptor& device, const std::unordered_set& backpropRoots, bool allocateNetworkMatrices) - { - if (m_computationNetwork != nullptr) - { - // TODO: We should either invalidate and readapt the network if he backpropRoots change compared to what was specified when the network - // was last constructed, to just recreate a new network. 
- // For now just disallow changing the backpropRoots after the network is created - if (!backpropRoots.empty() && (m_currentBackpropRoots != backpropRoots)) - LogicError("Changing backprop roots across different Forward calls on a CNTK composite Function is currently unsupported"); - - // TODO: Support changing the device across different invocations of the forward method on a Function instance - if (AsDeviceDescriptor(m_computationNetwork->GetDeviceId()) != device) - LogicError("Changing device across different Forward calls on a CNTK composite Function is currently unsupported"); - - } - else - { - m_computationNetwork = std::make_shared(AsCNTKImplDeviceId(device)); - - ComputationNetworkBuilder builder(*m_computationNetwork); - - // TODO: We currently only support one backprop root - if (backpropRoots.size() > 1) - LogicError("More than one backprop roots is currently unsupported"); - - auto placeholders = Placeholders(); - if (!placeholders.empty()) - InvalidArgument("All placeholders of a Function must be bound before performing a Forward computation on the Function!"); - - // Now recursively create the network in a top-down fashion - auto rootFunction = RootFunction(); - auto rootFunctionOutputs = rootFunction->Outputs(); - for (auto rootOutput : rootFunctionOutputs) - GetNode(rootOutput, m_computationNetwork, builder, m_variableToNodeMap, m_isVariableRootMap); - - // If any of the function outputs is not a root node, we need to explicitly add it to the 'output' group of the ComputationNetwork - for (auto rootOutput : rootFunctionOutputs) - { - if (!m_isVariableRootMap[rootOutput]) - m_computationNetwork->AddToNodeGroup(L"output", m_variableToNodeMap[rootOutput]); - } - - m_currentBackpropRoots = backpropRoots; - - // In case of recurrence, the inputs of some of the ComputationNodes are not attached due to cycles. 
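// [Editor's note] The attachment loop that follows implements the second pass of a two-pass
// build for recurrent graphs: nodes are created first (back-edges temporarily left null),
// and the null inputs are patched once every node exists. A simplified, stand-alone sketch
// of that pattern (hypothetical types, not CNTK classes):
//
//     struct SketchNode { std::vector<SketchNode*> inputs; };
//
//     void PatchBackEdgesSketch(std::unordered_map<int, SketchNode*>& nodeForVar,
//                               const std::unordered_map<int, std::vector<int>>& inputVarsForVar)
//     {
//         for (auto& varAndNode : nodeForVar)
//         {
//             auto& inputs = varAndNode.second->inputs;
//             if (std::find(inputs.begin(), inputs.end(), nullptr) == inputs.end())
//                 continue; // already fully attached
//
//             inputs.clear();
//             for (int inputVar : inputVarsForVar.at(varAndNode.first))
//                 inputs.push_back(nodeForVar.at(inputVar)); // every node exists by now
//         }
//     }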
- // Now attach those after we have created all ComputationNodes in the network - for (auto varNodePair : m_variableToNodeMap) - { - auto& currentComputationNode = varNodePair.second; - auto& currentComputationNodeInputs = currentComputationNode->GetInputs(); - auto& currentVar = varNodePair.first; - - if (std::find(currentComputationNodeInputs.begin(), currentComputationNodeInputs.end(), nullptr) != currentComputationNodeInputs.end()) - { - // This ComputationNode has at least one null input which now needs to be properly attached - - const PrimitiveFunction* primitiveFunc = dynamic_cast(currentVar.Owner().get()); - - // Let's reorder properly since the ordering of inputs of CNTK internal ComputationNode may be different from the PrimitiveFunction inputs ordering - auto inputVars = primitiveFunc->Inputs(); - ReorderAsCNTKComputationNodeInputs(primitiveFunc->OpType(), inputVars); - inputVars.resize(currentComputationNode->GetNumInputs()); - - std::vector inputNodesBasePtrs; - for (auto inputVar : inputVars) - inputNodesBasePtrs.push_back(m_variableToNodeMap[inputVar]); - - currentComputationNode->AttachInputs(inputNodesBasePtrs); - } - } - - m_computationNetwork->SetTraceLevel(Internal::GetComputationNetworkTraceLevel()); - m_computationNetwork->CompileNetwork(); - - // Verify that the shapes of the output Variables that we computed match the corresponding nodes in the ComputationNetwork - for (auto varNodePair : m_variableToNodeMap) - { - if (varNodePair.first.IsOutput()) - { - auto outputVar = varNodePair.first; - auto computationNodePtr = m_variableToNodeMap[outputVar]; - auto outputShape = outputVar.Shape(); - auto computationNodeSampleLayout = computationNodePtr->GetSampleLayout(); - if (((outputShape.Rank() == 0) && (computationNodeSampleLayout[0] != 1)) || - ((outputShape.Rank() != 0) && (computationNodeSampleLayout != AsTensorViewShape(outputShape)) && (computationNodeSampleLayout != AsTensorShape(outputShape)))) - { - LogicError("The output Variable shape %S does not match the SampleLayout shape %s of the corresponding ComputationNode in the network", outputShape.AsString().c_str(), ((std::string)computationNodeSampleLayout).c_str()); - } - } - } - - // Record the timestamps of Parameter values - assert(m_lastRecordedParameterValueTimeStamps.empty()); - auto functionParameters = Parameters(); - for (auto parameter : functionParameters) - m_lastRecordedParameterValueTimeStamps.insert({ parameter, parameter.CurrentValueTimeStamp() }); - } - - - if (!m_networkMatricesAllocated && allocateNetworkMatrices) - { - ComputationNodeBasePtr backpropRootNode; - - // Now recursively traverse the network in a top-down fashion - auto rootFunction = RootFunction(); - auto rootFunctionOutputs = rootFunction->Outputs(); - std::vector forwardRootNodes; - for (auto rootOutput : rootFunctionOutputs) - { - auto currentRootNode = m_variableToNodeMap[rootOutput]; - forwardRootNodes.push_back(currentRootNode); - - if (m_currentBackpropRoots.find(rootOutput) != m_currentBackpropRoots.end()) - backpropRootNode = currentRootNode; - } - - m_computationNetwork->AllocateAllMatrices(forwardRootNodes, {}, backpropRootNode); - m_networkMatricesAllocated = allocateNetworkMatrices; - } - - return m_computationNetwork; - } - - template - /*static*/ std::pair>, MBLayoutPtr> CompositeFunction::GetCNTKImplMatrixAndMBLayoutFromValueObject(Variable var, const ValuePtr& value) - { - if (var.GetDataType() != value->GetDataType()) - LogicError("The Variable's DataType %s does not match the corresponding Value's DataType %s", 
DataTypeName(var.GetDataType()), DataTypeName(value->GetDataType())); - - if (AsDataType() != value->GetDataType()) - LogicError("The specified ElementType %s does not match the DataType %s", typeid(ElementType).name(), DataTypeName(value->GetDataType())); - - // TODO: Is supplying dense data for an Input variable tagged as sparse, a fatal error? - if (IsSparseInput(var) && !value->IsSparse()) - InvalidArgument("Dense input data supplied for a sparse input Variable"); - - if (IsSparseInput(var) && (value->GetStorageFormat() != StorageFormat::SparseCSC)) - InvalidArgument("Sparse Input data must be in SparseCSC format"); - - auto varShape = var.Shape(); - auto valueShape = value->Shape(); - if (valueShape.Rank() < varShape.Rank()) - InvalidArgument("Value's rank should be >= the Variable's rank"); - - size_t maxAddionalValueAxes = std::max(2, var.DynamicAxes().size()); - if (valueShape.Rank() > (varShape.Rank() + maxAddionalValueAxes)) - InvalidArgument("Value rank should be larger than the Variable%S rank at most by number of dynamic axes", ParanthesizedName(var.Name()).c_str()); - - if (valueShape.SubShape(0, varShape.Rank()) != varShape) - { - InvalidArgument("The %s dimensions of the Value shape %S do not match the shape of the variable %S that it corresponds to!", - Internal::IsReversingTensorShapesInErrorMessagesEnabled() ? "trailing" : "leading", - AsStringForErrorReporting(valueShape).c_str(), - AsStringForErrorReporting(varShape).c_str()); - } - - if (var.DynamicAxes().empty()) - return{ value->Data()->GetMatrix(), nullptr }; - - if (var.DynamicAxes().size() > 2) - LogicError("More than 2 dynamic axis for a variable is currently unsupported"); - - auto mask = value->Mask(); - if ((mask != nullptr) && ((varShape.Rank() + mask->Shape().Rank()) != valueShape.Rank())) - InvalidArgument("Invalid Value object; the sum of the rank of the mask and data does not equal the Variable's rank + number of dynamic axes"); - - auto getNumTimeStepsAndSequencesFunc = [](const NDShape& maskShape) { - size_t maxNumTimeSteps = 1; - size_t numSequences = 1; - if (maskShape.Rank() > 0) - maxNumTimeSteps = maskShape[0]; - - if (maskShape.Rank() > 1) - numSequences = maskShape[1]; - - return std::pair(maxNumTimeSteps, numSequences); - }; - - size_t maxNumTimeSteps, numSequences; - std::tie(maxNumTimeSteps, numSequences) = getNumTimeStepsAndSequencesFunc(valueShape.SubShape(varShape.Rank())); - - auto getSequenceStartsAndLengthsFunc = [&getNumTimeStepsAndSequencesFunc](const NDMaskPtr& mask, std::vector& sequenceBeginIndices, std::vector& sequenceLengths) { - auto cpuMask = mask; - if (mask->Device() != DeviceDescriptor::CPUDevice()) - cpuMask = mask->DeepClone(DeviceDescriptor::CPUDevice()); - - const MaskKind* maskBuffer = cpuMask->DataBuffer(); - size_t maxNumTimeSteps, numSequences; - std::tie(maxNumTimeSteps, numSequences) = getNumTimeStepsAndSequencesFunc(mask->Shape()); - - for (size_t i = 0; i < numSequences; ++i) - { - MaskKind firstMaskEntry = maskBuffer[i * maxNumTimeSteps]; - if (firstMaskEntry == MaskKind::SequenceBegin) - sequenceBeginIndices[i] = 0; - else if (firstMaskEntry == MaskKind::Valid) - sequenceBeginIndices[i] = Microsoft::MSR::CNTK::SentinelValueIndicatingUnspecifedSequenceBeginIdx; - else - LogicError("The first entry of a mask should be Valid or SequenceBegin"); - - size_t currentSequenceLength = 1; - bool currentSequenceEndAlreadyFound = false; - for (size_t j = 1; j < maxNumTimeSteps; ++j) - { - if (maskBuffer[(i * maxNumTimeSteps) + j] == MaskKind::Invalid) - 
currentSequenceEndAlreadyFound = true; - else - { - if (currentSequenceEndAlreadyFound) - InvalidArgument("Invalid Value object; only trailing steps of a sequence can be masked"); - - currentSequenceLength++; - } - } - - sequenceLengths[i] = currentSequenceLength; - } - }; - - if ((numSequences == 1) || (maxNumTimeSteps == 1)) - { - // The data need not be shuffled - std::shared_ptr> matrixData = value->Data()->GetMatrix(varShape.Rank()); - auto layout = std::make_shared(); - if (!mask) - { - if (maxNumTimeSteps == 1) - layout->InitAsFrameMode(numSequences); - else - { - layout->Init(numSequences, maxNumTimeSteps); - layout->AddSequence(0, 0, 0, maxNumTimeSteps); - } - } - else - { - layout->Init(numSequences, maxNumTimeSteps); - - std::vector sequenceBeginIndices(numSequences, 0); - std::vector sequenceLengths(numSequences, maxNumTimeSteps); - getSequenceStartsAndLengthsFunc(mask, sequenceBeginIndices, sequenceLengths); - - for (size_t i = 0; i < numSequences; ++i) - layout->AddSequence(i, i, sequenceBeginIndices[i], sequenceLengths[i]); - } - - return{ matrixData , layout}; - } - else - { - std::vector sequenceBeginIndices(numSequences, 0); - std::vector sequenceLengths(numSequences, maxNumTimeSteps); - if (mask != nullptr) - getSequenceStartsAndLengthsFunc(mask, sequenceBeginIndices, sequenceLengths); - - bool hasTruncatedSequences = std::find_if(sequenceBeginIndices.begin(), sequenceBeginIndices.end(), [](const int& val) { return (val < 0); }) != sequenceBeginIndices.end(); - - auto layout = std::make_shared(); - std::vector> placement; - if (!hasTruncatedSequences) - { - std::vector sequences; - for (size_t i = 0; i < numSequences; ++i) - sequences.push_back({ i, SIZE_MAX, sequenceBeginIndices[i], sequenceLengths[i] }); - - std::vector rowAllocations; - layout->InitAsPackedSequences(sequences, placement, rowAllocations); - } - else - { - layout->Init(numSequences, maxNumTimeSteps); - - // We cannot pack as some of the sequences are truncated and thus all sequences have to be - // kept in their original parallel streams - placement.resize(numSequences); - for (size_t i = 0; i < numSequences; ++i) - { - layout->AddSequence(i, i, sequenceBeginIndices[i], sequenceLengths[i]); - - // Add the gap if there is one - if (sequenceLengths[i] < maxNumTimeSteps) - layout->AddSequence(GAP_SEQUENCE_ID, i, sequenceLengths[i], maxNumTimeSteps); - - placement[i] = std::make_pair(i, 0); - } - } - - if (maxNumTimeSteps != layout->GetNumTimeSteps()) - LogicError("The number of time steps in the packed MBLayout does not match the longest sequence's length in the Value object"); - - if (numSequences != layout->GetNumSequences()) - LogicError("The number of sequences in the packed MBLayout does not match the sequence count in the Value object"); - - // The data needs to be rearranged since CNTK requires sequences to be interleaved across timesteps - // Now generate the gather indices - auto matrixData = std::make_shared>(varShape.TotalSize(), - layout->GetNumCols(), - AsCNTKImplDeviceId(value->Device()), - value->IsSparse() ? 
MatrixType::SPARSE : MatrixType::DENSE, - AsCNTKImplMatrixFormat(value->GetStorageFormat())); - - std::vector sequencesShorterThanLongestSequence; - for (size_t i = 0; i < numSequences; ++i) - if (sequenceLengths[i] != maxNumTimeSteps) - sequencesShorterThanLongestSequence.push_back(i); - - // Set the source location for all gaps to be the last step of the first sequence that is shorter than the longest sequence in the batch - size_t sourceColIdxForInvalidColumns = sequencesShorterThanLongestSequence.empty() ? 0 : (((sequencesShorterThanLongestSequence[0] + 1) * maxNumTimeSteps) - 1); - std::vector gatherIndicesVector(layout->GetNumCols(), (ElementType)sourceColIdxForInvalidColumns); - for (size_t i = 0; i < numSequences; ++i) - { - size_t targetParallelStreamIdx = placement[i].first; - size_t targetStartIdxInParallelStream = placement[i].second; - for (size_t j = 0; j < sequenceLengths[i]; ++j) - gatherIndicesVector[((targetStartIdxInParallelStream + j) * layout->GetNumParallelSequences()) + targetParallelStreamIdx] = (ElementType)((i * maxNumTimeSteps) + j); - } - - auto gatherIdxMatrix = std::make_shared>(1, layout->GetNumCols(), gatherIndicesVector.data(), AsCNTKImplDeviceId(value->Device())); - matrixData->DoGatherColumnsOf(0, *gatherIdxMatrix, *(value->Data()->GetMatrix(varShape.Rank())), 1); - return{ matrixData, layout }; - } - } - - template - /*static*/ ValuePtr CompositeFunction::GetValueObjectFromCNTKImplMatrixAndMBLayout(const NDShape& sampleShape, const Matrix& matrix, const MBLayoutPtr& layout, bool readOnly /*= true*/) - { - NDShape valueDataShape = sampleShape; - - size_t maxNumTimeSteps = 1; - size_t numSequences = 1; - if (layout != nullptr) - { - maxNumTimeSteps = layout->GetNumTimeSteps(); - numSequences = layout->GetNumSequences(); - valueDataShape = valueDataShape.AppendShape({ maxNumTimeSteps, numSequences }); - } - - auto createMaskFunc = [](const MBLayoutPtr& layout, const DeviceDescriptor& device, std::vector& sequencesShorterThanLongestSequence) { - std::vector sequenceBeginFlags; - std::vector sequenceLengths; - sequencesShorterThanLongestSequence.clear(); - - size_t maxNumTimeSteps = layout->GetNumTimeSteps(); - size_t numSequences = layout->GetNumSequences(); - auto& layoutSequences = layout->GetAllSequences(); - - size_t sequenceIdx = 0; - bool allSequencesStartInThisMB = true; - bool allSequencesSameLength = true; - for (auto sequenceInfo : layoutSequences) - { - if (sequenceInfo.seqId != GAP_SEQUENCE_ID) - { - auto currentSequenceBeginIdx = std::max(0, sequenceInfo.tBegin); - auto currentSequenceEndIdx = std::min(maxNumTimeSteps, sequenceInfo.tEnd); - auto currentSequenceLength = (currentSequenceEndIdx - currentSequenceBeginIdx); - auto isCurrentSequenceBeginningInsideThisMB = sequenceInfo.tBegin >= 0; - - allSequencesStartInThisMB = allSequencesStartInThisMB && isCurrentSequenceBeginningInsideThisMB; - allSequencesSameLength = allSequencesSameLength && (currentSequenceLength == maxNumTimeSteps); - - sequenceBeginFlags.push_back(isCurrentSequenceBeginningInsideThisMB); - sequenceLengths.push_back(currentSequenceLength); - - if (currentSequenceLength != maxNumTimeSteps) - sequencesShorterThanLongestSequence.push_back(sequenceIdx); - - sequenceIdx++; - } - } - - if (!allSequencesStartInThisMB && (numSequences != layout->GetNumParallelSequences())) - LogicError("Cannot create an unpacked Value object from packed data where one or more sequences are truncated"); - - bool maskNeeded = !allSequencesSameLength || !allSequencesStartInThisMB; - - NDMaskPtr mask; - 
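// [Editor's note] Worked example (not part of this patch) of the packing that the gather /
// scatter index code in this file implements. With two sequences A (length 3) and B (length 2),
// maxNumTimeSteps == 3 and two parallel streams, CNTK stores columns interleaved across time
// steps, i.e. packedColumn = t * numParallelSequences + s:
//
//     t = 0      t = 1      t = 2
//     [A0, B0]   [A1, B1]   [A2, gap]
//
// The Value-side data is laid out sequence-major, flatColumn = s * maxNumTimeSteps + t, so the
// gather indices map packed column (t * 2 + s) back to flat column (s * 3 + t). All gap columns
// point at one arbitrary column (the last time-step slot of the first shorter sequence); their
// contents are ignored because the mask marks them invalid.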
if (maskNeeded) - { - mask = MakeSharedObject(NDShape({ maxNumTimeSteps, numSequences }), DeviceDescriptor::CPUDevice()); - for (size_t i = 0; i < numSequences; ++i) - if (sequenceBeginFlags[i]) - mask->MarkSequenceBegin({0, i}); - - for (auto shortSequenceIdx : sequencesShorterThanLongestSequence) - mask->InvalidateSection({ sequenceLengths[shortSequenceIdx], shortSequenceIdx }, { NDShape::InferredDimension, 1 }); - } - - return mask; - }; - - // No data shuffling needed if no layout or the layout has just one time-step or just one sequence - std::vector sequencesShorterThanLongestSequence; - if ((maxNumTimeSteps == 1) || (numSequences == 1)) - { - // Just create a view over the existing matrix itself - auto tensorView = new TensorView(std::make_shared>(matrix.AsReference()), AsTensorViewShape(valueDataShape)); - auto data = MakeSharedObject(AsDataType(), AsDeviceDescriptor(matrix.GetDeviceId()), AsStorageFormat(matrix.GetFormat()), valueDataShape, readOnly, tensorView); - if (layout == nullptr) - return MakeSharedObject(data); - else - { - auto mask = createMaskFunc(layout, AsDeviceDescriptor(matrix.GetDeviceId()), sequencesShorterThanLongestSequence); - return MakeSharedObject(data, mask); - } - } - - if (layout->GetNumCols() != matrix.GetNumCols()) - LogicError("Bad MBLayout: The number of columns in the MBLayout does not match the number of columns in the data matrix!"); - - // Reshuffle to data to unpack and uninterleave the CNTK form packed data - // Now generate the scatter indices - auto shuffledMatrixData = std::make_shared>(matrix.GetNumRows(), maxNumTimeSteps * numSequences, matrix.GetDeviceId(), matrix.GetMatrixType(), matrix.GetFormat()); - auto mask = createMaskFunc(layout, AsDeviceDescriptor(matrix.GetDeviceId()), sequencesShorterThanLongestSequence); - - // Set the target location of all gaps to be the last step of the first sequence that is shorter than the longest sequence in the batch - size_t targetColIdxForInvalidColumns = sequencesShorterThanLongestSequence.empty() ? 
0 : (((sequencesShorterThanLongestSequence[0] + 1) * maxNumTimeSteps) - 1); - std::vector scatterIndicesVector(layout->GetNumCols(), (ElementType)targetColIdxForInvalidColumns); - - size_t i = 0; - auto& layoutSequences = layout->GetAllSequences(); - for (auto sequenceInfo : layoutSequences) - { - if (sequenceInfo.seqId != GAP_SEQUENCE_ID) - { - size_t targetParallelStreamIdx = sequenceInfo.s; - auto currentSequenceBeginIdx = std::max(0, sequenceInfo.tBegin); - auto currentSequenceEndIdx = std::min(maxNumTimeSteps, sequenceInfo.tEnd); - size_t currentSequenceLength = (currentSequenceEndIdx - currentSequenceBeginIdx); - - for (size_t j = 0; j < currentSequenceLength; ++j) - scatterIndicesVector[((currentSequenceBeginIdx + j) * layout->GetNumParallelSequences()) + targetParallelStreamIdx] = (ElementType)((i * maxNumTimeSteps) + j); - - i++; - } - } - - auto scatterIdxMatrix = std::make_shared>(1, layout->GetNumCols(), scatterIndicesVector.data(), matrix.GetDeviceId()); - shuffledMatrixData->DoScatterColumnsOf(0, *scatterIdxMatrix, matrix, 1); - - auto tensorView = new TensorView(shuffledMatrixData, AsTensorViewShape(valueDataShape)); - auto data = MakeSharedObject(AsDataType(), AsDeviceDescriptor(matrix.GetDeviceId()), AsStorageFormat(shuffledMatrixData->GetFormat()), valueDataShape, readOnly, tensorView); - return MakeSharedObject(data, mask); - } - - template - /*static*/ ValuePtr CompositeFunction::GetValueObjectFromCNTKImplMatrixAndMBLayout(Variable var, const Matrix& matrix, const MBLayoutPtr& layout, bool readOnly /*= true*/) - { - if (var.DynamicAxes().size() > 2) - LogicError("More than 2 dynamic axis for a variable is currently unsupported"); - - if (AsDataType() != var.GetDataType()) - LogicError("The specified ElementType %s does not match the DataType %s", typeid(ElementType).name(), DataTypeName(var.GetDataType())); - - if ((layout != nullptr) && (matrix.GetNumRows() != var.Shape().TotalSize())) - LogicError("Unexpected matrix layout: The number of rows in the matrix does not match the sample size of the Variable"); - - return GetValueObjectFromCNTKImplMatrixAndMBLayout(var.Shape(), matrix, layout, readOnly); - } - - template - /*static*/ void CompositeFunction::PopulateComputationNodeValue(const std::pair& variableValue, ComputationNodeBasePtr& computationNode) - { - std::pair>, MBLayoutPtr> CNTKMatrixAndMBLayout; - auto packedValue = dynamic_cast(variableValue.second.get()); - if (packedValue) - CNTKMatrixAndMBLayout = packedValue->PackedData(); - else - CNTKMatrixAndMBLayout = GetCNTKImplMatrixAndMBLayoutFromValueObject(variableValue.first, variableValue.second); - - MBLayoutPtr layout = CNTKMatrixAndMBLayout.second; - - auto& nodeData = computationNode->As>()->Value(); - - // Switch the node matrix to the right matrix type - nodeData.AssignValuesOf(*CNTKMatrixAndMBLayout.first); - computationNode->GetMBLayout()->CopyFrom(layout); - } - - void CompositeFunction::PopulateNetworkInputs(const std::unordered_map& arguments) - { - std::vector inputNodes; - for (auto argumentValuePair : arguments) - { - auto argument = argumentValuePair.first; - auto argumentComputationNode = m_variableToNodeMap[argument]; - assert(argumentComputationNode); - inputNodes.push_back(argumentComputationNode); - - ValuePtr argumentValue = arguments.at(argument); - - MBLayoutPtr layout; - switch (argumentValue->GetDataType()) - { - case DataType::Float: - PopulateComputationNodeValue({ argument, argumentValue }, argumentComputationNode); - break; - case DataType::Double: - PopulateComputationNodeValue({ 
argument, argumentValue }, argumentComputationNode); - break; - default: - LogicError("Unsupported DataType %s", DataTypeName(argumentValue->GetDataType())); - break; - } - } - - m_computationNetwork->BumpEvalTimeStamp(inputNodes); - } - - template - /*static*/ void CompositeFunction::PopulateComputationNodeGradient(const std::pair& variableGradient, Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode) - { - std::pair>, MBLayoutPtr> CNTKMatrixAndMBLayout; - auto packedValue = dynamic_cast(variableGradient.second.get()); - if (packedValue) - CNTKMatrixAndMBLayout = packedValue->PackedData(); - else - CNTKMatrixAndMBLayout = GetCNTKImplMatrixAndMBLayoutFromValueObject(variableGradient.first, variableGradient.second); - - MBLayoutPtr layout = CNTKMatrixAndMBLayout.second; - auto nodeLayout = computationNode->GetMBLayout(); - if (((layout == nullptr) != (nodeLayout == nullptr)) || ((layout != nullptr) && (*layout != *nodeLayout))) - InvalidArgument("The layout of the specified gradient Value is incompatible with the layout of the corresponding Variable computed during Forward call"); - computationNode->As>()->AssignGradient(*CNTKMatrixAndMBLayout.first); - } - - // Assign the supplied gradients corresponding to the root(s) of the network to be backpropagated through the graph - void CompositeFunction::PopulateNetworkGradients(const std::unordered_map& gradients) - { - auto functionOutputs = this->Outputs(); - for (auto gradientVarValuePair : gradients) - { - // Only gradients for roots of the function can be specified - if (std::find(functionOutputs.begin(), functionOutputs.end(), gradientVarValuePair.first) == functionOutputs.end()) - InvalidArgument("Gradients cannot be specified for a Variable that is not an Output of the Function"); - - auto outputComputationNode = m_variableToNodeMap[gradientVarValuePair.first]; - ValuePtr gradientValue = gradientVarValuePair.second; - - switch (gradientValue->GetDataType()) - { - case DataType::Float: - PopulateComputationNodeGradient(gradientVarValuePair, outputComputationNode); - break; - case DataType::Double: - PopulateComputationNodeGradient(gradientVarValuePair, outputComputationNode); - break; - default: - LogicError("Unsupported DataType %s", DataTypeName(gradientValue->GetDataType())); - break; - } - } - } - - static NDShape GetValueShape(const Variable& var, const ComputationNodeBasePtr& computationNodePtr) - { - size_t outputValueNumAxes = var.Shape().Rank(); - - // Add the batch and dynamic axes if needed - if (computationNodePtr->GetMBLayout() != nullptr) - outputValueNumAxes += 2; - - std::vector outputShapeDims(outputValueNumAxes); - for (size_t i = 0; i < var.Shape().Rank(); ++i) - outputShapeDims[i] = computationNodePtr->GetSampleLayout().GetDim(i); - - if (computationNodePtr->GetMBLayout() != nullptr) - { - outputShapeDims[var.Shape().Rank()] = computationNodePtr->GetMBLayout()->GetNumTimeSteps(); - outputShapeDims[var.Shape().Rank() + 1] = computationNodePtr->GetMBLayout()->GetNumSequences(); - } - - return NDShape(outputShapeDims); - } - - /*static*/ void CompositeFunction::GetNodeOutputOrGradient(Variable var, ValuePtr& varValue, Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode, bool getGradient) - { - auto valueShape = GetValueShape(var, computationNode); - if (varValue != nullptr) - { - // TODO: The shape of the specified output Value object must match the actual output shape - if ((varValue->Shape() != valueShape) && (AsTensorShape(varValue->Shape()) != AsTensorShape(valueShape))) - InvalidArgument("The 
shape %S of the specified Value object for %s does not match the actual shape %S", AsStringForErrorReporting(varValue->Shape()).c_str(), getGradient ? "gradient" : "output", AsStringForErrorReporting(valueShape).c_str()); - } - - ValuePtr nodeValue; - auto layout = computationNode->GetMBLayout(); - switch (var.GetDataType()) - { - case DataType::Float: - { - auto& matrix = getGradient ? computationNode->As>()->Gradient() : computationNode->As>()->Value(); - if (varValue == nullptr) - nodeValue = MakeSharedObject(var.Shape(), std::make_shared>(matrix.AsReference()), layout, /*readOnly =*/ false); - else - nodeValue = GetValueObjectFromCNTKImplMatrixAndMBLayout(var, matrix, layout); - break; - } - case DataType::Double: - { - auto& matrix = getGradient ? computationNode->As>()->Gradient() : computationNode->As>()->Value(); - if (varValue == nullptr) - nodeValue = MakeSharedObject(var.Shape(), std::make_shared>(matrix.AsReference()), layout, /*readOnly =*/ false); - else - nodeValue = GetValueObjectFromCNTKImplMatrixAndMBLayout(var, matrix, layout); - break; - } - default: - LogicError("Unsupported DataType %s", DataTypeName(var.GetDataType())); - break; - } - - if (varValue == nullptr) - varValue = nodeValue; - else - varValue->CopyFrom(*nodeValue); - } - - void CompositeFunction::GetNetworkOutputs(std::unordered_map& outputs) - { - // Now copy the Forward values of output nodes from the network to outputs' Value objects - for (auto outputVarValuePair : outputs) - GetNodeOutputOrGradient(outputVarValuePair.first, outputs[outputVarValuePair.first], m_variableToNodeMap[outputVarValuePair.first], false /*getGradient*/); - } - - void CompositeFunction::GetNetworkGradients(std::unordered_map& gradients) - { - auto networkInputs = this->Inputs(); - // Now copy the gradient values of input nodes of the network to gradients' Value objects - for (auto gradientVarValuePair : gradients) - { - // Only gradients corresponding to inputs of the network can be obtained - if (std::find(networkInputs.begin(), networkInputs.end(), gradientVarValuePair.first) == networkInputs.end()) - InvalidArgument("Backpropagated gradient values can only be obtained for inputs of a Function"); - - // Gradients can only be obtained for parameter variables or input variables that NeedsGradient - if (!gradientVarValuePair.first.NeedsGradient()) - InvalidArgument("Gradient value incorrectly requested for an Output or Constant Variable, or an Input Variable with NeedsGradient setting of false"); - - auto computationNodePtr = m_variableToNodeMap[gradientVarValuePair.first]; - - if (!computationNodePtr->NeedsGradient()) - LogicError("Backpropagated gradient value cannot be read from a ComputationNode that has NeedsGradient set to false"); - - GetNodeOutputOrGradient(gradientVarValuePair.first, gradients[gradientVarValuePair.first], computationNodePtr, true /*getGradient*/); - } - } - - const std::vector& CompositeFunction::GetArgumentDependencies(const Variable& output) - { - assert(output.IsOutput()); - - auto iter = m_perOutputVarArgumentDependencies.find(output); - if (iter != m_perOutputVarArgumentDependencies.end()) - return iter->second; - - auto wrappedComposite = CompositeFunction::Create(output.Owner()); - m_perOutputVarArgumentDependencies[output] = wrappedComposite->Arguments(); - - return m_perOutputVarArgumentDependencies[output]; - } - - /*virtual*/ BackPropStatePtr CompositeFunction::Forward(const std::unordered_map& arguments, - std::unordered_map& outputs, - const DeviceDescriptor& computeDevice, - const 
std::unordered_set& outputsToRetainBackwardStateFor) - { - // Validate arguments and outputs - if (outputs.empty()) - InvalidArgument("CompositeFunction::Forward: At least one output has to be specified!"); - - // Make sure that the DataType of the variables and corresponding values match - // TODO: We need a better way to determine the ElementType for the network - auto dataType = DataType::Unknown; - for (auto variableValuePair : arguments) - { - if (dataType == DataType::Unknown) - dataType = variableValuePair.first.GetDataType(); - else if (dataType != variableValuePair.first.GetDataType()) - LogicError("CompositeFunction::Forward: The DataType of all arguments of the Function must be same"); - } - - if (dataType == DataType::Unknown) - { - for (auto variableValuePair : outputs) - { - if (dataType == DataType::Unknown) - dataType = variableValuePair.first.GetDataType(); - } - } - - if (dataType == DataType::Float) - GetComputationNetwork(computeDevice, outputsToRetainBackwardStateFor, true); - else if (dataType == DataType::Double) - GetComputationNetwork(computeDevice, outputsToRetainBackwardStateFor, true); - else - InvalidArgument("Unsupported DataType %s", DataTypeName(dataType)); - - std::unordered_set functionOutputs(this->Outputs().begin(), this->Outputs().end()); - std::vector outputsToEvaluate; - std::unordered_set requiredArguments; - for (auto outputVarValuePair : outputs) - { - // Ensure that only a subset of this function's outputs are being asked to be evaluated - if (functionOutputs.find(outputVarValuePair.first) == functionOutputs.end()) - InvalidArgument("Requested output is not an Ouptut of the Function"); - - auto& requiredArgumentsForCurrentOutput = GetArgumentDependencies(outputVarValuePair.first); - requiredArguments.insert(requiredArgumentsForCurrentOutput.begin(), requiredArgumentsForCurrentOutput.end()); - - auto outputComputationNode = m_variableToNodeMap[outputVarValuePair.first]; - outputsToEvaluate.push_back(outputComputationNode); - } - - // TODO: Avoid copying the data when possible - - // We should have argument values supplied for all required argument dependencies for the requested outputs - for (auto requiredArgument : requiredArguments) - { - if (arguments.find(requiredArgument) == arguments.end()) - InvalidArgument("Function::Forward: Required argument's (%S) value that the requested output(s) depend on has not been provided", requiredArgument.Name().c_str()); - } - - // Feed data into the arguments of the network - PopulateNetworkInputs(arguments); - - // Dropout nodes have an implicit input in the form of the random mask that is applied to its explicit input - // This mask is regerated every minibatch and hence dropout nodes with a non-zero dropout rate must me marked outdated - // w.r.t. 
inputs to force evaluation in each minibatch - list dropoutNodes = m_computationNetwork->GetNodesWithType(OperationNameOf(DropoutNode)); - for (auto& nodeIter : dropoutNodes) - nodeIter->SetEvalTimeStampOutdatedWrtAll(); - - // Bump the timestamp of the parameter nodes whose values have changed - for (auto& paramTimeStampRecord : m_lastRecordedParameterValueTimeStamps) - { - auto parameter = paramTimeStampRecord.first; - auto prevTimeStamp = paramTimeStampRecord.second; - auto newTimeStamp = parameter.CurrentValueTimeStamp(); - if (newTimeStamp > prevTimeStamp) - { - paramTimeStampRecord.second = newTimeStamp; - m_variableToNodeMap[parameter]->BumpEvalTimeStamp(); - } - } - - // The 'outputsToRetainBackwardStateFor' nodes also need to be evaluated if not already specified in 'outputs' - for (auto rootVarForBackprop : outputsToRetainBackwardStateFor) - { - if (functionOutputs.find(rootVarForBackprop) == functionOutputs.end()) - InvalidArgument("Requested outputs to retain backward state for is not an Ouptut of the Function"); - - if (outputs.find(rootVarForBackprop) == outputs.end()) - outputsToEvaluate.push_back(m_variableToNodeMap[rootVarForBackprop]); - } - - // TODO: Verify that values were supplied for all inputs that requested outputs depend on - - ScopedNetworkOperationMode modeGuard(m_computationNetwork, outputsToRetainBackwardStateFor.empty() ? NetworkOperationMode::inferring : NetworkOperationMode::training); - - m_computationNetwork->ForwardProp(outputsToEvaluate); - - GetNetworkOutputs(outputs); - - // TODO: How to deal with the specified 'computeDevice' - Variable evalTimeStampVariable; - if (arguments.empty()) - evalTimeStampVariable = Inputs()[0]; - else - evalTimeStampVariable = arguments.begin()->first; - - return (outputsToRetainBackwardStateFor.size() > 0) ? MakeSharedObject(this->shared_from_this(), computeDevice, std::make_pair(evalTimeStampVariable, m_variableToNodeMap[evalTimeStampVariable]->GetEvalTimeStamp())) : nullptr; - } - - /*virtual*/ void CompositeFunction::Backward(const BackPropStatePtr& state, - const std::unordered_map& rootGradientValues, - std::unordered_map& backPropagatedGradientValuesForInputs) - { - auto backpropState = dynamic_cast(state.get()); - if (backpropState == nullptr) - InvalidArgument("Invalid backprop state specified"); - - // TODO: Support multiple concurrent backprop states - if (backpropState->EvalTimeStamp().second != m_variableToNodeMap[backpropState->EvalTimeStamp().first]->GetEvalTimeStamp()) - LogicError("The specified backprop state specified cannot be used for backpropagation as the Function's internal state was modified by subsequent Forward calls to the function." 
- "This is not a user error but a shortcoming of the current implementation where multiple independent backprop states are not simultaneously supported"); - - if (rootGradientValues.size() > 1) - LogicError("Currently gradient backprop from only one of the Function Outputs is supported"); - - // TODO: Avoid copying the data when possible - - // Zero all gradients of nodes below the root nodes - for (auto rootGradientVarValuePair : rootGradientValues) - m_computationNetwork->ZeroInputGradients(m_variableToNodeMap[rootGradientVarValuePair.first]); - - // Feed data into the arguments of the network - PopulateNetworkGradients(rootGradientValues); - - // Backpropagate through the network - ScopedNetworkOperationMode modeGuard(m_computationNetwork, NetworkOperationMode::training); - - auto rootComputationNodePtr = m_variableToNodeMap[rootGradientValues.begin()->first]; - m_computationNetwork->GetNestedNetwork(rootComputationNodePtr)->Backprop(FrameRange(nullptr), true, true); - - GetNetworkGradients(backPropagatedGradientValuesForInputs); - - // TODO: How to deal with the specified 'computeDevice' - } - FunctionPtr UnaryOp(PrimitiveOpType op, const Variable& operand, Dictionary&& opConfig, const std::wstring& name) { std::vector operands = { operand }; diff --git a/Source/CNTKv2LibraryDll/MinibatchSource.cpp b/Source/CNTKv2LibraryDll/MinibatchSource.cpp index c8b2afba2..09b40b837 100644 --- a/Source/CNTKv2LibraryDll/MinibatchSource.cpp +++ b/Source/CNTKv2LibraryDll/MinibatchSource.cpp @@ -10,7 +10,6 @@ #include "MinibatchSource.h" #include "HeapMemoryProvider.h" #include "ReaderShim.h" -#include "Function.h" #include #include "Value.h" #include "MPIWrapper.h" diff --git a/Source/CNTKv2LibraryDll/PrimitiveFunction.cpp b/Source/CNTKv2LibraryDll/PrimitiveFunction.cpp new file mode 100644 index 000000000..f6f6ed175 --- /dev/null +++ b/Source/CNTKv2LibraryDll/PrimitiveFunction.cpp @@ -0,0 +1,654 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE.md file in the project root for full license information. 
+// + +#include "stdafx.h" +#include "PrimitiveFunction.h" +#include "ComputationNode.h" +#include "ReshapingNodes.h" +#include "EvaluationNodes.h" +#include "TrainingNodes.h" +#include "LinearAlgebraNodes.h" +#include "InputAndParamNodes.h" +#include "NonlinearityNodes.h" +#include "RecurrentNodes.h" +#include "Serialization.h" +#include "RNNNodes.h" + +using namespace Microsoft::MSR::CNTK; + +namespace CNTK +{ + // Names for the reduction operations as used by the CNTK ReduceElementsNode + /*static*/ const std::wstring PrimitiveFunction::InternalSumReductionOpName = L"Sum"; + /*static*/ const std::wstring PrimitiveFunction::InternalLogSumReductionOpName = L"LogSum"; + /*static*/ const std::wstring PrimitiveFunction::InternalMeanReductionOpName = L"Mean"; + /*static*/ const std::wstring PrimitiveFunction::InternalMaxReductionOpName = L"Max"; + /*static*/ const std::wstring PrimitiveFunction::InternalMinReductionOpName = L"Min"; + /*static*/ const std::wstring PrimitiveFunction::InternalAllReductionOpName = L"All"; + /*static*/ const std::wstring PrimitiveFunction::InternalAnyReductionOpName = L"Any"; + + // Names of the various attributes of CNTK primitive Functions + /*static*/ const std::wstring PrimitiveFunction::AttributeNameAxis = L"axis"; + /*static*/ const std::wstring PrimitiveFunction::AttributeNameAxis1 = L"axis1"; + /*static*/ const std::wstring PrimitiveFunction::AttributeNameAxis2 = L"axis2"; + /*static*/ const std::wstring PrimitiveFunction::AttributeNameAllowDuplicates = L"allowDuplicates"; + /*static*/ const std::wstring PrimitiveFunction::AttributeNameNumSamples = L"numSamples"; + /*static*/ const std::wstring PrimitiveFunction::AttributeNameDropoutRate = L"dropoutRate"; + /*static*/ const std::wstring PrimitiveFunction::AttributeNameNewShape = L"newShape"; + /*static*/ const std::wstring PrimitiveFunction::AttributeNameOutputRank = L"outputRank"; + /*static*/ const std::wstring PrimitiveFunction::AttributeNameInferInputRankToMap = L"inferInputRankToMap"; + /*static*/ const std::wstring PrimitiveFunction::AttributeNameOffset = L"offset"; + /*static*/ const std::wstring PrimitiveFunction::AttributeNameStrides = L"strides"; + /*static*/ const std::wstring PrimitiveFunction::AttributeNameSharing = L"sharing"; + /*static*/ const std::wstring PrimitiveFunction::AttributeNameAutoPadding = L"autoPadding"; + /*static*/ const std::wstring PrimitiveFunction::AttributeNameLowerPad = L"lowerPad"; + /*static*/ const std::wstring PrimitiveFunction::AttributeNameUpperPad = L"upperPad"; + /*static*/ const std::wstring PrimitiveFunction::AttributeNameTranspose = L"transpose"; + /*static*/ const std::wstring PrimitiveFunction::AttributeNameMaxTempMemSizeInSamples = L"maxTempMemSizeInSamples"; + /*static*/ const std::wstring PrimitiveFunction::AttributeNameROIOutputShape = L"roiOutputShape"; + /*static*/ const std::wstring PrimitiveFunction::AttributeNamePoolingType = L"poolingType"; + /*static*/ const std::wstring PrimitiveFunction::AttributeNamePoolingWindowShape = L"poolingWindowShape"; + /*static*/ const std::wstring PrimitiveFunction::AttributeNameSpatial = L"spatial"; + /*static*/ const std::wstring PrimitiveFunction::AttributeNameNormalizationTimeConstant = L"normalizationTimeConstant"; + /*static*/ const std::wstring PrimitiveFunction::AttributeNameBlendTimeConstant = L"blendTimeConstant"; + /*static*/ const std::wstring PrimitiveFunction::AttributeNameEpsilon = L"epsilon"; + /*static*/ const std::wstring PrimitiveFunction::AttributeNameUseCuDNNEngine = L"useCuDNNEngine"; + 
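// [Editor's note] Illustrative sketch (not part of this patch) of how these attribute keys
// are used: op-specific settings travel in the PrimitiveFunction's attribute Dictionary and
// are read back when the corresponding ComputationNode is created. The helper name below is
// hypothetical.
Dictionary MakeDropoutAttributesSketch(double dropoutRate)
{
    Dictionary attributes;
    attributes[PrimitiveFunction::AttributeNameDropoutRate] = dropoutRate;
    return attributes;
}
// CreateComputationNode then retrieves the setting with
// functionConfig[PrimitiveFunction::AttributeNameDropoutRate].Value<double>().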
/*static*/ const std::wstring PrimitiveFunction::AttributeNameNewDynamicAxes = L"newDynamicAxes"; + /*static*/ const std::wstring PrimitiveFunction::AttributeNameNewSequenceAxisLengthScalingFactor = L"newSequenceAxisLengthScalingFactor"; + /*static*/ const std::wstring PrimitiveFunction::AttributeNameNewSequenceAxisLengthAdditiveFactor = L"newSequenceAxisLengthAdditiveFactor"; + /*static*/ const std::wstring PrimitiveFunction::AttributeNameBeginIndex = L"beginIndex"; + /*static*/ const std::wstring PrimitiveFunction::AttributeNameEndIndex = L"endIndex"; + /*static*/ const std::wstring PrimitiveFunction::AttributeNameReductionOpName = L"reductionOpName"; + /*static*/ const std::wstring PrimitiveFunction::AttributeNameBidirectional = L"bidirectional"; + /*static*/ const std::wstring PrimitiveFunction::AttributeNameNumLayers = L"numLayers"; + /*static*/ const std::wstring PrimitiveFunction::AttributeNameHiddenSize = L"hiddenSize"; + /*static*/ const std::wstring PrimitiveFunction::AttributeNameRecurrentOp = L"recurrentOp"; + + /*static*/ std::vector PrimitiveFunction::GetOutputVariables(PrimitiveOpType op, + std::vector& inputs, + Function* owner, + Dictionary& functionConfig, + bool inferDimensions, + const std::wstring& functionName) + { + if (op == PrimitiveOpType::Combine) + return inputs; + + // We use the first non-constant input operand's DataType as the output DataType + // In case there are no non-constant known DataTypes, we just pick the first known operand DataType + // Also, all the known DataTypes of operands should match except for constants where coercion is allowed + DataType firstKnownInputDataType = DataType::Unknown; + DataType outputDataType = DataType::Unknown; + size_t i = 0; + while (i < inputs.size()) + { + auto input = inputs[i++]; + auto inputDataType = input.GetDataType(); + if (inputDataType != DataType::Unknown) + { + if (firstKnownInputDataType == DataType::Unknown) + firstKnownInputDataType = inputDataType; + + if (outputDataType == DataType::Unknown) + { + if (!input.IsConstant()) + outputDataType = inputDataType; + } + else + { + // The DataType of all operands should match except for Constants where we allow coercion + if ((inputDataType != DataType::Unknown) && (inputDataType != outputDataType) && !input.IsConstant()) + InvalidArgument("Primitive function with op type %S has operands with different DataTypes %s and %s", PrimitiveOpTypeName(op).c_str(), DataTypeName(outputDataType), DataTypeName(inputDataType)); + } + } + } + + if (outputDataType == DataType::Unknown) + outputDataType = firstKnownInputDataType; + + // We currently require that the inputs' dynamic axes, if any, match + std::vector outputDynamicAxes; + if ((op == PrimitiveOpType::SumAll) || + (op == PrimitiveOpType::SquaredError) || + (op == PrimitiveOpType::CrossEntropyWithSoftmax) || + (op == PrimitiveOpType::ClassificationError) || + (op == PrimitiveOpType::Logistic)) + { + outputDynamicAxes = std::vector({}); + } + else if (op == PrimitiveOpType::Where) + { + if (functionConfig.Contains(PrimitiveFunction::AttributeNameNewDynamicAxes)) + outputDynamicAxes = AsVector(functionConfig[PrimitiveFunction::AttributeNameNewDynamicAxes].Value>()); + else + { + if (inputs[0].DynamicAxes() == Axis::UnknownDynamicAxes()) + outputDynamicAxes = Axis::UnknownDynamicAxes(); + else + { + if (functionConfig.Contains(PrimitiveFunction::AttributeNameNewSequenceAxisLengthScalingFactor) && + functionConfig.Contains(PrimitiveFunction::AttributeNameNewSequenceAxisLengthAdditiveFactor)) + { + size_t 
newSequenceAxisLengthScalingFactor = functionConfig[PrimitiveFunction::AttributeNameNewSequenceAxisLengthScalingFactor].Value(); + int newSequenceAxisLengthAdditiveFactor = functionConfig[PrimitiveFunction::AttributeNameNewSequenceAxisLengthAdditiveFactor].Value(); + + auto derivedDynamicAxes = GetDerivedDynamicAxes(inputs[0].DynamicAxes()[0], newSequenceAxisLengthScalingFactor, newSequenceAxisLengthAdditiveFactor); + std::copy(derivedDynamicAxes.begin(), derivedDynamicAxes.end(), std::back_inserter(outputDynamicAxes)); + } + else + { + outputDynamicAxes.push_back(Axis::NewUniqueDynamicAxis(L"whereNodeDynamicAxis")); + } + + for (size_t i = 1; i < inputs[0].DynamicAxes().size(); ++i) + outputDynamicAxes.push_back(inputs[0].DynamicAxes()[i]); + + functionConfig[PrimitiveFunction::AttributeNameNewDynamicAxes] = AsDictionaryValueVector(outputDynamicAxes); + } + } + } + else if (op == PrimitiveOpType::ScatterPacked) + outputDynamicAxes = inputs[2].DynamicAxes(); + else if ((op == PrimitiveOpType::PackedIndex) || (op == PrimitiveOpType::GatherPacked)) + outputDynamicAxes = inputs[1].DynamicAxes(); + else if (op == PrimitiveOpType::ReconcileDynamicAxis) + outputDynamicAxes = inputs[1].DynamicAxes(); + else + { + auto allInputDynamicAxesEmpty = std::find_if(inputs.begin(), inputs.end(), [](const Variable& input) { return !input.DynamicAxes().empty(); }) == inputs.end(); + if (!allInputDynamicAxesEmpty) + { + outputDynamicAxes = Axis::UnknownDynamicAxes(); + for (auto inputVar : inputs) + { + auto currentInputDynamicAxes = inputVar.DynamicAxes(); + if (!currentInputDynamicAxes.empty() && (currentInputDynamicAxes != Axis::UnknownDynamicAxes())) + { + if (outputDynamicAxes == Axis::UnknownDynamicAxes()) + outputDynamicAxes = currentInputDynamicAxes; + else + { + if (currentInputDynamicAxes != outputDynamicAxes) + LogicError("Currently if an operand of a elementwise operation has any dynamic axes, those must match the dynamic axes of the other operands"); + } + } + } + } + } + + NDShape outputShape; + bool areAnyInputShapesUnknown = (std::find_if(inputs.begin(), inputs.end(), [](const Variable& input) { return input.Shape().IsUnknown(); }) != inputs.end()); + if (areAnyInputShapesUnknown) + outputShape = NDShape::Unknown; + else + { + switch (op) + { + case PrimitiveOpType::Negate: + case PrimitiveOpType::Sigmoid: + case PrimitiveOpType::Tanh: + case PrimitiveOpType::ReLU: + case PrimitiveOpType::Exp: + case PrimitiveOpType::Log: + case PrimitiveOpType::Sqrt: + case PrimitiveOpType::Floor: + case PrimitiveOpType::Abs: + case PrimitiveOpType::Reciprocal: + case PrimitiveOpType::Softmax: + case PrimitiveOpType::Hardmax: + case PrimitiveOpType::Dropout: + case PrimitiveOpType::Where: + case PrimitiveOpType::LogSoftmax: + { + assert(inputs.size() == 1); + outputShape = UnaryElementwiseOpOutputShape(inputs[0].Shape()); + break; + } + case PrimitiveOpType::TransposeAxes: + { + assert(inputs.size() == 1); + + auto axis1 = NormalizeStaticAxis(functionConfig[PrimitiveFunction::AttributeNameAxis1].Value(), inputs[0].Shape()); + auto axis2 = NormalizeStaticAxis(functionConfig[PrimitiveFunction::AttributeNameAxis2].Value(), inputs[0].Shape()); + + if (!axis1.IsStaticAxis() || !axis2.IsStaticAxis()) + LogicError("TransposeAxes operation currently does not support transposing dynamic axes"); + + VerifyStaticAxis(axis1, inputs[0].Shape()); + VerifyStaticAxis(axis2, inputs[0].Shape()); + + outputShape = inputs[0].Shape(); + std::swap(outputShape[axis1.StaticAxisIndex()], outputShape[axis2.StaticAxisIndex()]); + 
break; + } + case PrimitiveOpType::Slice: + { + assert(inputs.size() == 1); + auto axis = NormalizeStaticAxis(functionConfig[PrimitiveFunction::AttributeNameAxis].Value(), inputs[0].Shape()); + + auto beginIndex = functionConfig[PrimitiveFunction::AttributeNameBeginIndex].Value(); + auto endIndex = functionConfig[PrimitiveFunction::AttributeNameEndIndex].Value(); + if (!axis.IsStaticAxis()) + LogicError("Built-in Slice operation currently does not support slicing along dynamic axis"); + + VerifyStaticAxis(axis, inputs[0].Shape()); + + size_t sliceAxisDim = inputs[0].Shape()[axis.StaticAxisIndex()]; + int realBeginIndex = (beginIndex >= 0) ? beginIndex : beginIndex + sliceAxisDim; + int realEndIndex = (endIndex > 0) ? endIndex : endIndex + sliceAxisDim; + if ((sliceAxisDim < realEndIndex) || (realEndIndex < realBeginIndex) || (realBeginIndex < 0)) + RuntimeError("Slice operation: Index range [%d,%d), interpreted as [%d,%d), is invalid for input's shape ([%S]).", + beginIndex, + endIndex, + realBeginIndex, + realEndIndex, + AsStringForErrorReporting(inputs[0].Shape()).c_str()); + + auto outputTensorShape = AsTensorShape(inputs[0].Shape()); + + // propagate as much as we can + if ((axis.StaticAxisIndex() < (int)outputTensorShape.GetRank()) && (0 <= realBeginIndex) && (realBeginIndex <= realEndIndex) && (realEndIndex <= sliceAxisDim)) + outputTensorShape.NarrowTo(axis.StaticAxisIndex(), realBeginIndex, realEndIndex); + + outputShape = AsNDShape(outputTensorShape, /*allowNonFlattenableTensorShapes = */ true); + break; + } + case PrimitiveOpType::Reshape: + { + auto newShape = functionConfig[PrimitiveFunction::AttributeNameNewShape].Value(); + outputShape = ReshapeOutputShape(inputs[0].Shape(), newShape); + break; + } + case PrimitiveOpType::ROIPooling: + { + assert(inputs.size() == 2); + auto convMapShape = inputs[0].Shape(); + auto roisShape = inputs[1].Shape(); + auto roiOutputShape = functionConfig[PrimitiveFunction::AttributeNameROIOutputShape].Value(); + + auto outW = roiOutputShape[0]; + auto outH = roiOutputShape[1]; + auto numChannels = convMapShape[2]; + auto roisPerImage = roisShape[1]; + + if (roiOutputShape.Rank() != 2) + InvalidArgument("ROIPoolingNode: roi output shape must have two dimensions ([W x H])."); + + if (convMapShape[0] < outW || convMapShape[1] < outH) + InvalidArgument("ROIPoolingNode: inputWidth must >= windowWidth and inputHeight must >= windowHeight."); + + if (convMapShape[2] < 1) + InvalidArgument("ROIPoolingNode: input must have at least one channel ([W x H x C])."); + + if (roisShape[0] != 4) + InvalidArgument("ROIPoolingNode: ROI input must have the following shape: [4 x roisPerImage]."); + + if (roisPerImage < 1) + InvalidArgument("ROIPoolingNode: ROI input must contain at least one ROI ([4 x roisPerImage])."); + + outputShape = { outW, outH, numChannels, roisPerImage }; + break; + } + case PrimitiveOpType::Pooling: + { + assert(inputs.size() == 1); + auto poolingWindowsShape = functionConfig[PrimitiveFunction::AttributeNamePoolingWindowShape].Value(); + auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value(); + auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value(); + auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value(); + auto autoPadding = AsVector(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value>()); + NDShape outputMapCount = { 1 }; + std::vector sharing = { true }; + auto inputShape = inputs[0].Shape(); + + // In case of pooling if the kernel shape is unknown, 
then treat it as global pooling. + if (poolingWindowsShape == NDShape::Unknown) + { + if ((std::find(autoPadding.begin(), autoPadding.end(), true) != autoPadding.end()) || + (lowerPad.TotalSize() > 0) || (upperPad.TotalSize() > 0)) + RuntimeError("Padding isn't allowed for Unknown shape!"); + + poolingWindowsShape = inputShape.SubShape(0, inputShape.Rank()-1); + functionConfig[PrimitiveFunction::AttributeNamePoolingWindowShape] = poolingWindowsShape; + } + + outputShape = ConvolutionOpOutputShape(op, inputShape, poolingWindowsShape, outputMapCount, strides, sharing, autoPadding, lowerPad, upperPad, false, inferDimensions); + break; + } + case PrimitiveOpType::SumAll: + assert(inputs.size() == 1); + outputShape = {1}; + break; + case PrimitiveOpType::Plus: + case PrimitiveOpType::LogPlus: + case PrimitiveOpType::Minus: + case PrimitiveOpType::ElementTimes: + case PrimitiveOpType::Equal: + case PrimitiveOpType::NotEqual: + case PrimitiveOpType::Less: + case PrimitiveOpType::LessEqual: + case PrimitiveOpType::Greater: + case PrimitiveOpType::GreaterEqual: + case PrimitiveOpType::PastValue: + case PrimitiveOpType::FutureValue: + { + assert(inputs.size() == 2); + if ((op == PrimitiveOpType::PastValue) || (op == PrimitiveOpType::FutureValue)) + { + Variable inputOperandVar = inputs[0]; + Variable initialStateVar = inputs[1]; + + // TODO: We currently only support input operand with 1 dynamic axis for PastValue/FutureValue + if ((inputOperandVar.DynamicAxes() != Axis::UnknownDynamicAxes()) && (inputOperandVar.DynamicAxes().size() != 2)) + LogicError("Currently PastValue/FutureValue Function only supports input operand with 2 dynamic axis (1 sequence-axis and 1 batch-axis)"); + + if (!initialStateVar.DynamicAxes().empty()) + LogicError("Currently PastValue/FutureValue Function does not support initial state operand with dynamic axes!"); + } + + outputShape = BinaryElementwiseOpOutputShape(op, inputs[0], inputs[1], true, inferDimensions); + break; + } + case PrimitiveOpType::Times: + { + assert(inputs.size() == 2); + auto outputRank = functionConfig[PrimitiveFunction::AttributeNameOutputRank].Value(); + auto inferInputRankToMap = functionConfig[PrimitiveFunction::AttributeNameInferInputRankToMap].Value(); + outputShape = TimesOpOutputShape(inputs[0], inputs[1], outputRank, inferInputRankToMap, inferDimensions); + break; + } + case PrimitiveOpType::TransposeTimes: + { + assert(inputs.size() == 2); + + auto transposeShapeFunc = [](const NDShape& shape) { + NDShape transposedShape(std::max(2, shape.Rank()), 1); + for (size_t i = 0; i < shape.Rank(); ++i) + transposedShape[transposedShape.Rank() - i - 1] = shape[i]; + + return transposedShape; + }; + + if (inputs[0].Shape().Rank() > 2) + LogicError("TransposeTimes operation currently only supports %s operands of rank 1 or 2", Internal::IsReversingTensorShapesInErrorMessagesEnabled() ? 
"right" : "left"); + + NDShape transposedLeftOperandShape = transposeShapeFunc(inputs[0].Shape()); + Variable dummyLeftOperand = PlaceholderVariable(transposedLeftOperandShape); + size_t outputRank = functionConfig[PrimitiveFunction::AttributeNameOutputRank].Value(); + outputShape = TimesOpOutputShape(dummyLeftOperand, inputs[1], outputRank, -1, inferDimensions); + if (dummyLeftOperand.Shape() != transposedLeftOperandShape) + inputs[0].m_dataFields->m_shape = transposeShapeFunc(dummyLeftOperand.Shape()); + + break; + } + case PrimitiveOpType::Convolution: + { + assert(inputs.size() == 2); + auto& strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value(); + auto& lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value(); + auto& upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value(); + auto sharing = AsVector(functionConfig[PrimitiveFunction::AttributeNameSharing].Value>()); + auto autoPadding = AsVector(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value>()); + bool transpose = functionConfig[PrimitiveFunction::AttributeNameTranspose].Value(); + if (inputs[0].Shape().Rank() < inputs[1].Shape().Rank()) + InvalidArgument("The convolution map should have at least as many axes as the shape of the input it operates on!"); + + NDShape outputMapCount, kernelShape; + std::tie(outputMapCount, kernelShape) = GetConvolutionOutputMapCountAndKernelShape(inputs[0].Shape(), inputs[1].Shape()); + auto originalKernelShape = kernelShape; + outputShape = ConvolutionOpOutputShape(op, inputs[1].Shape(), kernelShape, outputMapCount, strides, sharing, autoPadding, lowerPad, upperPad, transpose, inferDimensions); + if (originalKernelShape != kernelShape) + { + for (size_t i = 0; i < kernelShape.Rank(); ++i) + inputs[0].m_dataFields->m_shape[i] = kernelShape[i]; + } + + functionConfig[PrimitiveFunction::AttributeNameSharing] = AsDictionaryValueVector(sharing); + functionConfig[PrimitiveFunction::AttributeNameAutoPadding] = AsDictionaryValueVector(autoPadding); + break; + } + case PrimitiveOpType::Logistic: + case PrimitiveOpType::SquaredError: + case PrimitiveOpType::CrossEntropyWithSoftmax: + case PrimitiveOpType::ClassificationError: + { + if ((op == PrimitiveOpType::ClassificationError) || (op == PrimitiveOpType::Logistic)) + assert(inputs.size() >= 2); + else + assert(inputs.size() == 2); + + if ((inputs[0].Shape().Rank() > 2) || ((inputs[0].Shape().Rank() > 1) && (inputs[0].Shape()[1] != 1))) + InvalidArgument("The shape of input operands for the %S operation should have at most one axis", PrimitiveOpTypeName(op).c_str()); + + auto predictionShape = inputs[0].Shape(); + auto labelsShape = inputs[1].Shape(); + if (predictionShape != labelsShape) + RuntimeError("Prediction output operand's shape %S is incompatible with label operand's shape %S for the %S operation", AsStringForErrorReporting(predictionShape).c_str(), AsStringForErrorReporting(labelsShape).c_str(), PrimitiveOpTypeName(op).c_str()); + + std::vector reductionAxes; + for (int i = 0; i < (int)inputs[0].Shape().Rank(); ++i) + reductionAxes.push_back(i); + + outputShape = ReductionOpOutputShape(op, predictionShape, reductionAxes, /*preserveReductionAxes =*/ false); + break; + } + case PrimitiveOpType::ReduceElements: + { + assert(inputs.size() == 1); + auto reductionAxis = NormalizeStaticAxis(functionConfig[PrimitiveFunction::AttributeNameAxis].Value(), inputs[0].Shape()); + if (reductionAxis == Axis::AllStaticAxes()) + outputShape = {}; + else + { + std::vector reductionAxes = { 
reductionAxis.StaticAxisIndex() }; + outputShape = ReductionOpOutputShape(op, inputs[0].Shape(), reductionAxes, /*preserveReductionAxes =*/ true); + } + break; + } + case PrimitiveOpType::BatchNormalization: + { + assert(inputs.size() == 5); + auto spatial = functionConfig[PrimitiveFunction::AttributeNameSpatial].Value(); + outputShape = BatchNormalizationOutputShape(inputs, spatial, inferDimensions); + break; + } + case PrimitiveOpType::PackedIndex: + outputShape = UnaryElementwiseOpOutputShape(inputs[1].Shape()); + break; + case PrimitiveOpType::GatherPacked: + { + bool sourceHasDynamicAxis = !inputs[0].DynamicAxes().empty(); + + // inherit tensor dimension from sourceData, minus the last (column or time) dimension. TODO this needs to become simpler... + if (sourceHasDynamicAxis) + outputShape = inputs[0].Shape(); + else + { + if (inputs[0].Shape().Rank() > 1) + outputShape = outputShape.SubShape(0, outputShape.Rank() - 1); + else + outputShape = {}; + } + + break; + } + case PrimitiveOpType::ScatterPacked: + { + if (inputs[0].DynamicAxes().empty() || inputs[1].DynamicAxes().empty() || inputs[2].DynamicAxes().empty()) + InvalidArgument("ScatterPacked requires all its operands to have dynamic axes"); + + outputShape = inputs[0].Shape(); + break; + } + case PrimitiveOpType::Clip: + assert(inputs.size() == 3); + outputShape = UnaryElementwiseOpOutputShape(inputs[0].Shape()); + break; + case PrimitiveOpType::Select: + assert(inputs.size() == 3); + outputShape = NaryElementwiseOpOutputShape(op, inputs, true); + break; + case PrimitiveOpType::Splice: + { + assert(inputs.size() >= 2); + auto maxInputRank = MaxInputRank(inputs); + auto spliceAxis = NormalizeStaticAxis(functionConfig[PrimitiveFunction::AttributeNameAxis].Value(), NDShape(maxInputRank)); + + if (!spliceAxis.IsStaticAxis()) + LogicError("Splice operation currently does not support splicing along dynamic axis"); + + if (spliceAxis.StaticAxisIndex() < 0) + InvalidArgument("Splice: The axis argument's static axis index must be >= 0!"); + + outputShape = SpliceOutputShape(inputs, spliceAxis.StaticAxisIndex()); + break; + } + case PrimitiveOpType::RandomSample: + case PrimitiveOpType::RandomSampleInclusionFrequency: + { + auto numSamples = functionConfig[PrimitiveFunction::AttributeNameNumSamples].Value(); + auto allowDuplicates = functionConfig[PrimitiveFunction::AttributeNameAllowDuplicates].Value(); + + if (numSamples == 0) + InvalidArgument("Number of requested samples is zero."); + + let& shape = inputs[0].Shape(); + size_t numClasses = shape.Dimensions()[0]; + + if (numClasses != NDShape::InferredDimension && !allowDuplicates && numClasses <= numSamples) + InvalidArgument("For sampling without duplicates the number of requested samples (%lu) needs to be less than the number of classes (%lu).", numSamples, numClasses); + + // within this block we handle RandomSample and RandomSampleInclusionFrequency + if (op == PrimitiveOpType::RandomSampleInclusionFrequency) + outputShape = shape; + else + { + vector dimensions{ numSamples, numClasses }; + outputShape = NDShape(dimensions); + } + + break; + } + case PrimitiveOpType::OptimizedRNNStack: + { + assert(inputs.size() == 2); + auto operand = inputs[0]; + auto parameter = inputs[1]; + if (operand.Shape().Rank() != 1) + InvalidArgument("OptimizedRNNStack: input must have rank 1; actual input rank is %lu", operand.Shape().Rank()); + if (operand.DynamicAxes().empty()) + InvalidArgument("OptimizedRNNStack: input must have at least one dynamic axis"); + auto numLayers = 
functionConfig[PrimitiveFunction::AttributeNameNumLayers].Value(); + if (numLayers == 0) + InvalidArgument("Number of layers in OptimizedRNNStack operation should be positive"); + auto bidirectional = functionConfig[PrimitiveFunction::AttributeNameBidirectional].Value(); + auto hiddenSize = functionConfig[PrimitiveFunction::AttributeNameHiddenSize].Value(); + + // output dims + outputShape = operand.Shape(); + outputShape[0] = (bidirectional ? 2 : 1) * hiddenSize; + // infer input size + // Note: Output dim is second axis, so say initOutputRank=-1. + if (parameter.Shape().Rank() == 2) + { + const auto recurrentOp = functionConfig[PrimitiveFunction::AttributeNameRecurrentOp].Value(); + const auto attributes = RnnAttributes(bidirectional, numLayers, hiddenSize, recurrentOp, -1); + const auto numParameters = attributes.GetNumParameters(operand.Shape().TotalSize()); + std::vector> newOperandShapes = { { parameter, std::move(NDShape({ numParameters.first, numParameters.second })) } }; + UpdateOperandShapes(newOperandShapes); + } + break; + } + case PrimitiveOpType::ReconcileDynamicAxis: + { + assert(inputs.size() == 2); + auto operand = inputs[0]; + auto layout = inputs[1]; + if (operand.DynamicAxes().empty()) + InvalidArgument("ReconcileDynamicAxis: input must have at least one dynamic axis"); + if (layout.DynamicAxes().empty()) + InvalidArgument("ReconcileDynamicAxis: layout must have at least one dynamic axis"); + outputShape = operand.Shape(); + break; + } + default: + LogicError("Specified op %S not yet supported", PrimitiveOpTypeName(op).c_str()); + break; + } + } + + return{ OutputVariable(outputShape, outputDataType, owner, outputDynamicAxes, functionName.empty() ? L"" : functionName + L"_output") }; + } + + static const std::wstring s_primitiveFunctionTypeValue = L"PrimitiveFunction"; + + /*virtual*/ Dictionary PrimitiveFunction::Serialize() const + { + Dictionary dict; + + dict[versionKey] = CurrentVersion(); + dict[typeKey] = s_primitiveFunctionTypeValue; + dict[opKey] = static_cast(m_op); + dict[attributesKey] = Attributes(); + dict[uidKey] = Uid(); + dict[nameKey] = Name(); + + auto inputs = Inputs(); + vector inputUids; + inputUids.reserve(inputs.size()); + for (auto& input : inputs) + { + inputUids.push_back(input.Uid()); + } + + dict[inputsKey] = std::move(inputUids); + + return dict; + } + + /*static*/ FunctionPtr PrimitiveFunction::Deserialize(const Dictionary& dict, + const std::unordered_map& uidToVariableMap, + const CNTK::DeviceDescriptor& device) + { + static const vector s_requiredDictionaryKeys = { typeKey, opKey, uidKey, attributesKey, inputsKey, nameKey }; + + size_t version = ValidateDictionary(dict, s_requiredDictionaryKeys, s_primitiveFunctionTypeValue, s_serializationVersion); + + PrimitiveOpType op = PrimitiveOpType(dict[opKey].Value()); + + // The hard requirement that the serialization depends on is that + // new op type values are only added to the end of the list, after Combine. + // This also applies to other enums (DataType, VariableKind, etc.) 
+ if (op > PrimitiveOpType::Combine) + { + LogicError("Unexpected variable '%ls':'%u' (%s).", + opKey.c_str(), + static_cast::type>(op), + GetVersionsString(s_serializationVersion, version).c_str()); + } + + const auto& uid = dict[uidKey].Value(); + const auto& name = dict[nameKey].Value(); + auto attributes = dict[attributesKey].Value(); + const auto& inputUids = dict[inputsKey].Value>(); + + std::vector inputs; + inputs.reserve(inputUids.size()); + + for (const auto& dictionaryValue : inputUids) + { + const auto& inputUid = dictionaryValue.Value(); + if (uidToVariableMap.find(inputUid) == uidToVariableMap.end()) + { + LogicError("There are no inputs corresponging to input uid = '%ls' " + "(%s).", inputUid.c_str(), GetVersionsString(s_serializationVersion, version).c_str()); + } + inputs.push_back(uidToVariableMap.at(inputUid)); + } + + return std::shared_ptr(new PrimitiveFunction(op, inputs, std::move(attributes), name, uid), + [](PrimitiveFunction* ptr) { delete ptr; }); + } +} diff --git a/Source/CNTKv2LibraryDll/Function.h b/Source/CNTKv2LibraryDll/PrimitiveFunction.h similarity index 70% rename from Source/CNTKv2LibraryDll/Function.h rename to Source/CNTKv2LibraryDll/PrimitiveFunction.h index 4c1f0fe4c..8742a8439 100644 --- a/Source/CNTKv2LibraryDll/Function.h +++ b/Source/CNTKv2LibraryDll/PrimitiveFunction.h @@ -8,12 +8,9 @@ #include "stdafx.h" #include "CNTKLibrary.h" #include "PrimitiveOpType.h" -#include -#include "ComputationNetwork.h" #include "Utils.h" #include "ConvolveGeometry.h" #include "ConvolutionalNodes.h" -#include "BackCompat.h" namespace std { @@ -648,256 +645,4 @@ namespace CNTK // a more meaningful message when trying to load a new model with a stale binary. static const size_t s_serializationVersion = 2; }; - - class CNTKBackPropState final : public BackPropState - { - public: - CNTKBackPropState(const FunctionPtr& function, const DeviceDescriptor& computeDevice, const std::pair& evalTimeStamp) - : BackPropState(function, computeDevice), m_evalTimeStamp(evalTimeStamp) - {} - - std::pair EvalTimeStamp() const - { - return m_evalTimeStamp; - } - - private: - std::pair m_evalTimeStamp; - }; - typedef std::shared_ptr CNTKBackPropStatePtr; - - class CompositeFunction; - typedef std::shared_ptr CompositeFunctionPtr; - - class CompositeFunction final : public Function - { - friend class Function; - friend class Trainer; - friend class CompositeMinibatchSource; - friend class PackedValue; - - template - friend inline std::shared_ptr MakeSharedObject(CtorArgTypes&& ...ctorArgs); - - friend void Internal::SaveAsLegacyModel(const FunctionPtr& rootFunction, const std::wstring& modelFile); - - friend void ComputeInputPerDimMeansAndInvStdDevs(const MinibatchSourcePtr& minibatchSource, - std::unordered_map>& computedMeanAndInvStdDevs, - const DeviceDescriptor& device /*= DeviceDescriptor::CPUDevice()*/); - - static std::atomic s_nextAutoGeneratedDynamicAxis; - - static const std::wstring CompositeFunctionOpName; - - public: - static const std::wstring InternalDefaultDynamicAxisName; - static const std::wstring InternalNoSequenceAxisName; - - static Axis NextAutoGeneratedDynamicAxis() - { - static const std::wstring s_autoGeneratedDynamicAxisNamePrefix = L"autoGeneratedDynamicAxis_"; - return Axis(s_autoGeneratedDynamicAxisNamePrefix + std::to_wstring(s_nextAutoGeneratedDynamicAxis++)); - } - - public: - static CompositeFunctionPtr Create(const FunctionPtr& rootFunction, const std::wstring& name = L"", const std::wstring& uid = L"") - { - std::unordered_set visitedFunctions; - - 
// Call Collect to get the set of all functions in the graph - Collect(rootFunction, visitedFunctions); - - return MakeSharedObject(rootFunction, std::move(visitedFunctions), name, uid); - } - - virtual BackPropStatePtr Forward(const std::unordered_map& arguments, - std::unordered_map& outputs, - const DeviceDescriptor& computeDevice, - const std::unordered_set& outputsToRetainBackwardStateFor) override; - - virtual void Backward(const BackPropStatePtr& state, - const std::unordered_map& rootGradientValues, - std::unordered_map& backPropagatedGradientValuesForInputs) override; - - virtual Dictionary Serialize() const override; - - virtual size_t CurrentVersion() const override { return s_serializationVersion; } - - static FunctionPtr Deserialize(const Dictionary& dictionary, const CNTK::DeviceDescriptor& device); - - virtual const std::wstring& OpName() override - { - return CompositeFunctionOpName; - } - - template - static void Traverse(const FunctionPtr& rootFunction, const FunctionType& functor) - { - std::unordered_set visitedFunctions; - Traverse(rootFunction, visitedFunctions, functor); - } - - // Recursively traverses the Function graph underlying the 'rootFunction' invoking the provided functor for all visited nodes in the graph. - template - static void Traverse(const FunctionPtr& rootFunction, std::unordered_set& visitedFunctions, const FunctionType& functor) - { - visitedFunctions.insert(rootFunction); - functor(rootFunction); - - std::vector rootFunctionInputs = rootFunction->Inputs(); - for (const auto& rootInput : rootFunctionInputs) - { - if (rootInput.IsOutput() && visitedFunctions.find(rootInput.Owner()) == visitedFunctions.end()) - { - const auto& function = rootInput.Owner(); - Traverse(function, visitedFunctions, functor); - } - } - } - - private: - virtual void ReplacePlaceholdersInPlace(const std::unordered_map& placeholderReplacements, - std::unordered_set& visitedFunctions, - std::unordered_set& replacedPlaceholders) override; - - CompositeFunction(const FunctionPtr& rootFunction, std::unordered_set&& allPrimitiveFunctions, const std::wstring& name, const std::wstring& uid = Internal::GenerateUid(L"CompositeFunction")) - : Function({}, rootFunction->Outputs(), Dictionary(), rootFunction, name, uid), - m_allPrimitiveFunctions(std::move(allPrimitiveFunctions)), m_networkMatricesAllocated(false) - {} - - std::vector DetermineInputs() const - { - const auto& root = RootFunction(); - std::unordered_set visitedFunctions; - return DetermineInputs(root, visitedFunctions); - } - - // Recursively traverses the Function graph and populates the provided set of functions. 
- static void Collect(const FunctionPtr& rootFunction, std::unordered_set& functions) - { - // Call Traverse to get the set of all functions in the graph - Traverse(rootFunction, functions, [](const FunctionPtr& f){}); - } - - // Recursively traverses the Function graph underlying the 'rootFunction' to determine all the leaves (aka inputs) of the graph - static std::vector DetermineInputs(const FunctionPtr& rootFunction, std::unordered_set& visitedFunctions) - { - vector functions; - std::vector inputs; - std::unordered_set uniqueInputs; - Traverse(rootFunction, visitedFunctions, [&inputs, &uniqueInputs](const FunctionPtr& f){ - std::vector functionInputs = f->Inputs(); - for (auto input : functionInputs) - { - if (!input.IsOutput() && uniqueInputs.find(input) == uniqueInputs.end()) - { - inputs.push_back(input); - uniqueInputs.insert(input); - } - } - }); - - return inputs; - } - - template - Microsoft::MSR::CNTK::ComputationNetworkPtr GetComputationNetwork(const DeviceDescriptor& device, const std::unordered_set& backpropRoots, bool allocateNetworkMatrices); - - template - static Microsoft::MSR::CNTK::ComputationNodeBasePtr CreateComputationNode(const Variable& variable, - PrimitiveFunction* primitiveFunction, - const std::vector>>& inputNodes, - Microsoft::MSR::CNTK::ComputationNetworkPtr& network, - std::unordered_map& variableToNodeMap); - - template - static Microsoft::MSR::CNTK::ComputationNodeBasePtr GetOutputVariableNode(const Variable& variable, - Microsoft::MSR::CNTK::ComputationNetworkPtr& network, - Microsoft::MSR::CNTK::ComputationNetworkBuilder& builder, - std::unordered_map& variableToNodeMap, - std::unordered_map& isVariableRootMap); - - template - static Microsoft::MSR::CNTK::ComputationNodeBasePtr GetNode(const Variable& variable, Microsoft::MSR::CNTK::ComputationNetworkPtr& network, - Microsoft::MSR::CNTK::ComputationNetworkBuilder& builder, - std::unordered_map& variableToNodeMap, - std::unordered_map& isVariableRootMap); - - template - static void PopulateComputationNodeValue(const std::pair& variableValue, Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode); - void PopulateNetworkInputs(const std::unordered_map& arguments); - - template - static void PopulateComputationNodeGradient(const std::pair& variableGradient, Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode); - void PopulateNetworkGradients(const std::unordered_map& gradients); - - static void GetNodeOutputOrGradient(Variable var, ValuePtr& varValue, Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode, bool getGradient); - void GetNetworkOutputs(std::unordered_map& outputs); - void GetNetworkGradients(std::unordered_map& gradients); - - template - static std::pair>, Microsoft::MSR::CNTK::MBLayoutPtr> GetCNTKImplMatrixAndMBLayoutFromValueObject(Variable var, const ValuePtr& value); - - template - static ValuePtr GetValueObjectFromCNTKImplMatrixAndMBLayout(const NDShape& sampleShape, const Microsoft::MSR::CNTK::Matrix& matrix, const Microsoft::MSR::CNTK::MBLayoutPtr& layout, bool readOnly = true); - template - static ValuePtr GetValueObjectFromCNTKImplMatrixAndMBLayout(Variable var, const Microsoft::MSR::CNTK::Matrix& matrix, const Microsoft::MSR::CNTK::MBLayoutPtr& layout, bool readOnly = true); - - const std::vector& GetArgumentDependencies(const Variable& output); - - private: - - // Set of all primitive functions in the graph underlying 'this' Function. 
Also keeps the primitive Function objects alive - // by holding strong references to them - std::unordered_set m_allPrimitiveFunctions; - - // A map from Variable objects to ComputationNode objects in the ComputationNetwork instance that implements 'this' Composite Function - std::unordered_map m_variableToNodeMap; - - // A map that tells whether a Variable in the graph underlying 'this' Function is a root of the graph - std::unordered_map m_isVariableRootMap; - - Microsoft::MSR::CNTK::ComputationNetworkPtr m_computationNetwork; - - // The backpropRoots sepecified in the most recent 'Forward' call on 'this' Function. - // This indicates for which of its roots has 'this' Function retained required intermediate - // states from the previos Forward call to be able to backpropagate gradients backwards from in - // the next 'Backward' call. - std::unordered_set m_currentBackpropRoots; - - std::unordered_map> m_perOutputVarArgumentDependencies; - - bool m_networkMatricesAllocated; - - std::unordered_map m_lastRecordedParameterValueTimeStamps; - - static const size_t s_serializationVersion = 1; - }; - - inline std::vector DynamicAxesFromInternalDynamicAxisName(const std::wstring& internalDynamicAxisName) - { - std::vector inputVarDynamicAxes; - if (internalDynamicAxisName.substr(0, CNTK::CompositeFunction::InternalDefaultDynamicAxisName.length()) == CNTK::CompositeFunction::InternalDefaultDynamicAxisName) - inputVarDynamicAxes = { CNTK::Axis::DefaultDynamicAxis(), CNTK::Axis::DefaultBatchAxis() }; - else if (internalDynamicAxisName.substr(0, CNTK::CompositeFunction::InternalNoSequenceAxisName.length()) == CNTK::CompositeFunction::InternalNoSequenceAxisName) - inputVarDynamicAxes = { CNTK::Axis::DefaultBatchAxis() }; - else - inputVarDynamicAxes = { CNTK::Axis(internalDynamicAxisName), CNTK::Axis::DefaultBatchAxis() }; - - return inputVarDynamicAxes; - } - - // Construct the dynamic axis name to be used internally for the CNTK InputNodes - inline std::wstring InternalDynamicAxisNameFromDynamicAxes(const std::vector& dynamicAxes) - { - if (dynamicAxes.empty()) - LogicError("Empty dynamic axes set"); - - if (dynamicAxes == std::vector({ Axis::DefaultBatchAxis() })) - return CompositeFunction::InternalNoSequenceAxisName; - else if (dynamicAxes == std::vector({ Axis::DefaultDynamicAxis(), Axis::DefaultBatchAxis() })) - return CompositeFunction::InternalDefaultDynamicAxisName; - else - return dynamicAxes[0].Name(); - } } diff --git a/Source/CNTKv2LibraryDll/Trainer.cpp b/Source/CNTKv2LibraryDll/Trainer.cpp index 2c2bf20aa..98e26b420 100644 --- a/Source/CNTKv2LibraryDll/Trainer.cpp +++ b/Source/CNTKv2LibraryDll/Trainer.cpp @@ -6,7 +6,6 @@ #include "stdafx.h" #include "CNTKLibrary.h" #include "Utils.h" -#include "Function.h" #include "Serialization.h" namespace diff --git a/Source/CNTKv2LibraryDll/Utils.cpp b/Source/CNTKv2LibraryDll/Utils.cpp index dc13cdff1..a0323aa29 100644 --- a/Source/CNTKv2LibraryDll/Utils.cpp +++ b/Source/CNTKv2LibraryDll/Utils.cpp @@ -13,8 +13,8 @@ #include "CNTKLibrary.h" #include "Utils.h" #include "Serialization.h" -#include "Function.h" #include +#include "PrimitiveFunction.h" using namespace std; diff --git a/Source/CNTKv2LibraryDll/Value.cpp b/Source/CNTKv2LibraryDll/Value.cpp index 899814e68..cd2d7965e 100644 --- a/Source/CNTKv2LibraryDll/Value.cpp +++ b/Source/CNTKv2LibraryDll/Value.cpp @@ -10,9 +10,9 @@ #endif #include "CNTKLibrary.h" +#include "CompositeFunction.h" #include "Utils.h" #include "Value.h" -#include "Function.h" namespace CNTK { diff --git 
a/Source/CNTKv2LibraryDll/Value.h b/Source/CNTKv2LibraryDll/Value.h
index 085f700a6..dfb0f6408 100644
--- a/Source/CNTKv2LibraryDll/Value.h
+++ b/Source/CNTKv2LibraryDll/Value.h
@@ -8,6 +8,7 @@
 #include "stdafx.h"
 #include "CNTKLibrary.h"
 #include "Sequences.h"
+#include "TensorView.h"
 #include "Utils.h"
 
 namespace CNTK
diff --git a/Source/CNTKv2LibraryDll/Variable.cpp b/Source/CNTKv2LibraryDll/Variable.cpp
index 953fdb42a..f9066789a 100644
--- a/Source/CNTKv2LibraryDll/Variable.cpp
+++ b/Source/CNTKv2LibraryDll/Variable.cpp
@@ -5,8 +5,8 @@
 #include "stdafx.h"
 #include "CNTKLibrary.h"
+#include "CompositeFunction.h"
 #include "Serialization.h"
-#include "Function.h"
 #include "InputAndParamNodes.h"
 
 namespace CNTK