CNTK v2 library: Refactor Function.h and .cpp into multiple files

Parent: 2b8b3047df
Commit: 3723b405b0

Makefile | 2
@@ -409,6 +409,8 @@ CNTKLIBRARY_COMMON_SRC =\
$(SOURCEDIR)/CNTKv2LibraryDll/BackCompat.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/Common.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/Function.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/PrimitiveFunction.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/CompositeFunction.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/NDArrayView.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/NDMask.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/Trainer.cpp \
@@ -6,7 +6,8 @@
#include "stdafx.h"
#include "CNTKLibrary.h"
#include "BackCompat.h"
#include "Function.h"
#include "PrimitiveFunction.h"
#include "CompositeFunction.h"
#include "ComputationNetworkBuilder.h"
#include "Utils.h"
#include "ComputationNode.h"
@@ -135,12 +135,13 @@
<ClInclude Include="API\CNTKLibraryExperimental.h" />
<ClInclude Include="API\CNTKLibraryInternals.h" />
<ClInclude Include="BackCompat.h" />
<ClInclude Include="CompositeFunction.h" />
<ClInclude Include="DataParallelDistributedTrainer.h" />
<ClInclude Include="DistributedCommunicator.h" />
<ClInclude Include="DistributedTrainerBase.h" />
<ClInclude Include="Function.h" />
<ClInclude Include="Learner.h" />
<ClInclude Include="MinibatchSource.h" />
<ClInclude Include="PrimitiveFunction.h" />
<ClInclude Include="PrimitiveOpType.h" />
<ClInclude Include="Serialization.h" />
<ClInclude Include="Utils.h" />
@@ -151,6 +152,7 @@
<ItemGroup>
<ClCompile Include="BackCompat.cpp" />
<ClCompile Include="Common.cpp" />
<ClCompile Include="CompositeFunction.cpp" />
<ClCompile Include="ComputeInputStatistics.cpp" />
<ClCompile Include="DataParallelDistributedTrainer.cpp" />
<ClCompile Include="DistributedCommunicator.cpp" />
@@ -165,6 +167,7 @@
<ClCompile Include="MinibatchSource.cpp" />
<ClCompile Include="NDArrayView.cpp" />
<ClCompile Include="NDMask.cpp" />
<ClCompile Include="PrimitiveFunction.cpp" />
<ClCompile Include="proto\CNTK.pb.cc.VS_wrapper.cpp" />
<ClCompile Include="Serialization.cpp" />
<ClCompile Include="stdafx.cpp">
@@ -22,6 +22,8 @@
<ClCompile Include="DistributedCommunicator.cpp" />
<ClCompile Include="DataParallelDistributedTrainer.cpp" />
<ClCompile Include="DistributedTrainerBase.cpp" />
<ClCompile Include="CompositeFunction.cpp" />
<ClCompile Include="PrimitiveFunction.cpp" />
</ItemGroup>
<ItemGroup>
<ClInclude Include="stdafx.h" />
@@ -33,7 +35,6 @@
<ClInclude Include="API\CNTKLibraryInternals.h">
<Filter>API</Filter>
</ClInclude>
<ClInclude Include="Function.h" />
<ClInclude Include="Learner.h" />
<ClInclude Include="MinibatchSource.h" />
<ClInclude Include="API\CNTKLibraryExperimental.h">
@@ -46,6 +47,8 @@
<ClInclude Include="DataParallelDistributedTrainer.h" />
<ClInclude Include="DistributedTrainerBase.h" />
<ClInclude Include="BackCompat.h" />
<ClInclude Include="CompositeFunction.h" />
<ClInclude Include="PrimitiveFunction.h" />
</ItemGroup>
<ItemGroup>
<Filter Include="API">
Diff for this file not shown because it is too large.
@ -0,0 +1,267 @@
|
|||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "stdafx.h"
|
||||
#include "CNTKLibrary.h"
|
||||
#include "PrimitiveFunction.h"
|
||||
#include "ComputationNetwork.h"
|
||||
#include "BackCompat.h"
|
||||
|
||||
namespace CNTK
|
||||
{
|
||||
class CNTKBackPropState final : public BackPropState
|
||||
{
|
||||
public:
|
||||
CNTKBackPropState(const FunctionPtr& function, const DeviceDescriptor& computeDevice, const std::pair<Variable, int64_t>& evalTimeStamp)
|
||||
: BackPropState(function, computeDevice), m_evalTimeStamp(evalTimeStamp)
|
||||
{}
|
||||
|
||||
std::pair<Variable, int64_t> EvalTimeStamp() const
|
||||
{
|
||||
return m_evalTimeStamp;
|
||||
}
|
||||
|
||||
private:
|
||||
std::pair<Variable, int64_t> m_evalTimeStamp;
|
||||
};
|
||||
typedef std::shared_ptr<CNTKBackPropState> CNTKBackPropStatePtr;
|
||||
|
||||
class CompositeFunction;
|
||||
typedef std::shared_ptr<CompositeFunction> CompositeFunctionPtr;
|
||||
|
||||
class CompositeFunction final : public Function
|
||||
{
|
||||
friend class Function;
|
||||
friend class Trainer;
|
||||
friend class CompositeMinibatchSource;
|
||||
friend class PackedValue;
|
||||
|
||||
template <typename T, typename ...CtorArgTypes>
|
||||
friend inline std::shared_ptr<T> MakeSharedObject(CtorArgTypes&& ...ctorArgs);
|
||||
|
||||
friend void Internal::SaveAsLegacyModel(const FunctionPtr& rootFunction, const std::wstring& modelFile);
|
||||
|
||||
friend void ComputeInputPerDimMeansAndInvStdDevs(const MinibatchSourcePtr& minibatchSource,
|
||||
std::unordered_map<StreamInformation, std::pair<NDArrayViewPtr, NDArrayViewPtr>>& computedMeanAndInvStdDevs,
|
||||
const DeviceDescriptor& device /*= DeviceDescriptor::CPUDevice()*/);
|
||||
|
||||
static std::atomic<unsigned int> s_nextAutoGeneratedDynamicAxis;
|
||||
|
||||
static const std::wstring CompositeFunctionOpName;
|
||||
|
||||
public:
|
||||
static const std::wstring InternalDefaultDynamicAxisName;
|
||||
static const std::wstring InternalNoSequenceAxisName;
|
||||
|
||||
static Axis NextAutoGeneratedDynamicAxis()
|
||||
{
|
||||
static const std::wstring s_autoGeneratedDynamicAxisNamePrefix = L"autoGeneratedDynamicAxis_";
|
||||
return Axis(s_autoGeneratedDynamicAxisNamePrefix + std::to_wstring(s_nextAutoGeneratedDynamicAxis++));
|
||||
}
|
||||
|
||||
public:
|
||||
static CompositeFunctionPtr Create(const FunctionPtr& rootFunction, const std::wstring& name = L"", const std::wstring& uid = L"")
|
||||
{
|
||||
std::unordered_set<FunctionPtr> visitedFunctions;
|
||||
|
||||
// Call Collect to get the set of all functions in the graph
|
||||
Collect(rootFunction, visitedFunctions);
|
||||
|
||||
return MakeSharedObject<CompositeFunction>(rootFunction, std::move(visitedFunctions), name, uid);
|
||||
}
|
||||
|
||||
virtual BackPropStatePtr Forward(const std::unordered_map<Variable, ValuePtr>& arguments,
|
||||
std::unordered_map<Variable, ValuePtr>& outputs,
|
||||
const DeviceDescriptor& computeDevice,
|
||||
const std::unordered_set<Variable>& outputsToRetainBackwardStateFor) override;
|
||||
|
||||
virtual void Backward(const BackPropStatePtr& state,
|
||||
const std::unordered_map<Variable, ValuePtr>& rootGradientValues,
|
||||
std::unordered_map<Variable, ValuePtr>& backPropagatedGradientValuesForInputs) override;
|
||||
|
||||
virtual Dictionary Serialize() const override;
|
||||
|
||||
virtual size_t CurrentVersion() const override { return s_serializationVersion; }
|
||||
|
||||
static FunctionPtr Deserialize(const Dictionary& dictionary, const CNTK::DeviceDescriptor& device);
|
||||
|
||||
virtual const std::wstring& OpName() override
|
||||
{
|
||||
return CompositeFunctionOpName;
|
||||
}
|
||||
|
||||
template <typename FunctionType>
|
||||
static void Traverse(const FunctionPtr& rootFunction, const FunctionType& functor)
|
||||
{
|
||||
std::unordered_set<FunctionPtr> visitedFunctions;
|
||||
Traverse(rootFunction, visitedFunctions, functor);
|
||||
}
|
||||
|
||||
// Recursively traverses the Function graph underlying the 'rootFunction' invoking the provided functor for all visited nodes in the graph.
|
||||
template <typename FunctionType>
|
||||
static void Traverse(const FunctionPtr& rootFunction, std::unordered_set<FunctionPtr>& visitedFunctions, const FunctionType& functor)
|
||||
{
|
||||
visitedFunctions.insert(rootFunction);
|
||||
functor(rootFunction);
|
||||
|
||||
std::vector<Variable> rootFunctionInputs = rootFunction->Inputs();
|
||||
for (const auto& rootInput : rootFunctionInputs)
|
||||
{
|
||||
if (rootInput.IsOutput() && visitedFunctions.find(rootInput.Owner()) == visitedFunctions.end())
|
||||
{
|
||||
const auto& function = rootInput.Owner();
|
||||
Traverse(function, visitedFunctions, functor);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
virtual void ReplacePlaceholdersInPlace(const std::unordered_map<Variable, Variable>& placeholderReplacements,
|
||||
std::unordered_set<const Function*>& visitedFunctions,
|
||||
std::unordered_set<Variable>& replacedPlaceholders) override;
|
||||
|
||||
CompositeFunction(const FunctionPtr& rootFunction, std::unordered_set<FunctionPtr>&& allPrimitiveFunctions, const std::wstring& name, const std::wstring& uid = Internal::GenerateUid(L"CompositeFunction"))
|
||||
: Function({}, rootFunction->Outputs(), Dictionary(), rootFunction, name, uid),
|
||||
m_allPrimitiveFunctions(std::move(allPrimitiveFunctions)), m_networkMatricesAllocated(false)
|
||||
{}
|
||||
|
||||
std::vector<Variable> DetermineInputs() const
|
||||
{
|
||||
const auto& root = RootFunction();
|
||||
std::unordered_set<FunctionPtr> visitedFunctions;
|
||||
return DetermineInputs(root, visitedFunctions);
|
||||
}
|
||||
|
||||
// Recursively traverses the Function graph and populates the provided set of functions.
|
||||
static void Collect(const FunctionPtr& rootFunction, std::unordered_set<FunctionPtr>& functions)
|
||||
{
|
||||
// Call Traverse to get the set of all functions in the graph
|
||||
Traverse(rootFunction, functions, [](const FunctionPtr& f){});
|
||||
}
|
||||
|
||||
// Recursively traverses the Function graph underlying the 'rootFunction' to determine all the leaves (aka inputs) of the graph
|
||||
static std::vector<Variable> DetermineInputs(const FunctionPtr& rootFunction, std::unordered_set<FunctionPtr>& visitedFunctions)
|
||||
{
|
||||
vector<FunctionPtr> functions;
|
||||
std::vector<Variable> inputs;
|
||||
std::unordered_set<Variable> uniqueInputs;
|
||||
Traverse(rootFunction, visitedFunctions, [&inputs, &uniqueInputs](const FunctionPtr& f){
|
||||
std::vector<Variable> functionInputs = f->Inputs();
|
||||
for (auto input : functionInputs)
|
||||
{
|
||||
if (!input.IsOutput() && uniqueInputs.find(input) == uniqueInputs.end())
|
||||
{
|
||||
inputs.push_back(input);
|
||||
uniqueInputs.insert(input);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
return inputs;
|
||||
}
|
||||
|
||||
template <typename ElementType>
|
||||
Microsoft::MSR::CNTK::ComputationNetworkPtr GetComputationNetwork(const DeviceDescriptor& device, const std::unordered_set<Variable>& backpropRoots, bool allocateNetworkMatrices);
|
||||
|
||||
template <typename ElementType>
|
||||
static Microsoft::MSR::CNTK::ComputationNodeBasePtr CreateComputationNode(const Variable& variable,
|
||||
PrimitiveFunction* primitiveFunction,
|
||||
const std::vector<std::shared_ptr<Microsoft::MSR::CNTK::ComputationNode<ElementType>>>& inputNodes,
|
||||
Microsoft::MSR::CNTK::ComputationNetworkPtr& network,
|
||||
std::unordered_map<Variable, Microsoft::MSR::CNTK::ComputationNodeBasePtr>& variableToNodeMap);
|
||||
|
||||
template <typename ElementType>
|
||||
static Microsoft::MSR::CNTK::ComputationNodeBasePtr GetOutputVariableNode(const Variable& variable,
|
||||
Microsoft::MSR::CNTK::ComputationNetworkPtr& network,
|
||||
Microsoft::MSR::CNTK::ComputationNetworkBuilder<ElementType>& builder,
|
||||
std::unordered_map<Variable, Microsoft::MSR::CNTK::ComputationNodeBasePtr>& variableToNodeMap,
|
||||
std::unordered_map<Variable, bool>& isVariableRootMap);
|
||||
|
||||
template <typename ElementType>
|
||||
static Microsoft::MSR::CNTK::ComputationNodeBasePtr GetNode(const Variable& variable, Microsoft::MSR::CNTK::ComputationNetworkPtr& network,
|
||||
Microsoft::MSR::CNTK::ComputationNetworkBuilder<ElementType>& builder,
|
||||
std::unordered_map<Variable, Microsoft::MSR::CNTK::ComputationNodeBasePtr>& variableToNodeMap,
|
||||
std::unordered_map<Variable, bool>& isVariableRootMap);
|
||||
|
||||
template <typename ElementType>
|
||||
static void PopulateComputationNodeValue(const std::pair<Variable, ValuePtr>& variableValue, Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode);
|
||||
void PopulateNetworkInputs(const std::unordered_map<Variable, ValuePtr>& arguments);
|
||||
|
||||
template <typename ElementType>
|
||||
static void PopulateComputationNodeGradient(const std::pair<Variable, ValuePtr>& variableGradient, Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode);
|
||||
void PopulateNetworkGradients(const std::unordered_map<Variable, ValuePtr>& gradients);
|
||||
|
||||
static void GetNodeOutputOrGradient(Variable var, ValuePtr& varValue, Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode, bool getGradient);
|
||||
void GetNetworkOutputs(std::unordered_map<Variable, ValuePtr>& outputs);
|
||||
void GetNetworkGradients(std::unordered_map<Variable, ValuePtr>& gradients);
|
||||
|
||||
template <typename ElementType>
|
||||
static std::pair<std::shared_ptr<const Microsoft::MSR::CNTK::Matrix<ElementType>>, Microsoft::MSR::CNTK::MBLayoutPtr> GetCNTKImplMatrixAndMBLayoutFromValueObject(Variable var, const ValuePtr& value);
|
||||
|
||||
template <typename ElementType>
|
||||
static ValuePtr GetValueObjectFromCNTKImplMatrixAndMBLayout(const NDShape& sampleShape, const Microsoft::MSR::CNTK::Matrix<ElementType>& matrix, const Microsoft::MSR::CNTK::MBLayoutPtr& layout, bool readOnly = true);
|
||||
template <typename ElementType>
|
||||
static ValuePtr GetValueObjectFromCNTKImplMatrixAndMBLayout(Variable var, const Microsoft::MSR::CNTK::Matrix<ElementType>& matrix, const Microsoft::MSR::CNTK::MBLayoutPtr& layout, bool readOnly = true);
|
||||
|
||||
const std::vector<Variable>& GetArgumentDependencies(const Variable& output);
|
||||
|
||||
private:
|
||||
|
||||
// Set of all primitive functions in the graph underlying 'this' Function. Also keeps the primitive Function objects alive
|
||||
// by holding strong references to them
|
||||
std::unordered_set<FunctionPtr> m_allPrimitiveFunctions;
|
||||
|
||||
// A map from Variable objects to ComputationNode objects in the ComputationNetwork instance that implements 'this' Composite Function
|
||||
std::unordered_map<Variable, Microsoft::MSR::CNTK::ComputationNodeBasePtr> m_variableToNodeMap;
|
||||
|
||||
// A map that tells whether a Variable in the graph underlying 'this' Function is a root of the graph
|
||||
std::unordered_map<Variable, bool> m_isVariableRootMap;
|
||||
|
||||
Microsoft::MSR::CNTK::ComputationNetworkPtr m_computationNetwork;
|
||||
|
||||
// The backpropRoots specified in the most recent 'Forward' call on 'this' Function.
// This indicates for which of its roots 'this' Function has retained the required intermediate
// states from the previous Forward call, so that gradients can be backpropagated from them in
// the next 'Backward' call.
|
||||
std::unordered_set<Variable> m_currentBackpropRoots;
|
||||
|
||||
std::unordered_map<Variable, std::vector<Variable>> m_perOutputVarArgumentDependencies;
|
||||
|
||||
bool m_networkMatricesAllocated;
|
||||
|
||||
std::unordered_map<Parameter, size_t> m_lastRecordedParameterValueTimeStamps;
|
||||
|
||||
static const size_t s_serializationVersion = 1;
|
||||
};
|
||||
|
||||
inline std::vector<CNTK::Axis> DynamicAxesFromInternalDynamicAxisName(const std::wstring& internalDynamicAxisName)
|
||||
{
|
||||
std::vector<CNTK::Axis> inputVarDynamicAxes;
|
||||
if (internalDynamicAxisName.substr(0, CNTK::CompositeFunction::InternalDefaultDynamicAxisName.length()) == CNTK::CompositeFunction::InternalDefaultDynamicAxisName)
|
||||
inputVarDynamicAxes = { CNTK::Axis::DefaultDynamicAxis(), CNTK::Axis::DefaultBatchAxis() };
|
||||
else if (internalDynamicAxisName.substr(0, CNTK::CompositeFunction::InternalNoSequenceAxisName.length()) == CNTK::CompositeFunction::InternalNoSequenceAxisName)
|
||||
inputVarDynamicAxes = { CNTK::Axis::DefaultBatchAxis() };
|
||||
else
|
||||
inputVarDynamicAxes = { CNTK::Axis(internalDynamicAxisName), CNTK::Axis::DefaultBatchAxis() };
|
||||
|
||||
return inputVarDynamicAxes;
|
||||
}
|
||||
|
||||
// Construct the dynamic axis name to be used internally for the CNTK InputNodes
|
||||
inline std::wstring InternalDynamicAxisNameFromDynamicAxes(const std::vector<Axis>& dynamicAxes)
|
||||
{
|
||||
if (dynamicAxes.empty())
|
||||
LogicError("Empty dynamic axes set");
|
||||
|
||||
if (dynamicAxes == std::vector<Axis>({ Axis::DefaultBatchAxis() }))
|
||||
return CompositeFunction::InternalNoSequenceAxisName;
|
||||
else if (dynamicAxes == std::vector<Axis>({ Axis::DefaultDynamicAxis(), Axis::DefaultBatchAxis() }))
|
||||
return CompositeFunction::InternalDefaultDynamicAxisName;
|
||||
else
|
||||
return dynamicAxes[0].Name();
|
||||
}
|
||||
}
@@ -6,7 +6,7 @@
#include "stdafx.h"
#include "CNTKLibrary.h"
#include "Utils.h"
#include "Function.h"
#include "CompositeFunction.h"
#include <tuple>
#include "ComputationNetworkBuilder.h"
Diff for this file not shown because it is too large.
@@ -10,7 +10,6 @@
#include "MinibatchSource.h"
#include "HeapMemoryProvider.h"
#include "ReaderShim.h"
#include "Function.h"
#include <tuple>
#include "Value.h"
#include "MPIWrapper.h"
@ -0,0 +1,654 @@
|
|||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
|
||||
//
|
||||
|
||||
#include "stdafx.h"
|
||||
#include "PrimitiveFunction.h"
|
||||
#include "ComputationNode.h"
|
||||
#include "ReshapingNodes.h"
|
||||
#include "EvaluationNodes.h"
|
||||
#include "TrainingNodes.h"
|
||||
#include "LinearAlgebraNodes.h"
|
||||
#include "InputAndParamNodes.h"
|
||||
#include "NonlinearityNodes.h"
|
||||
#include "RecurrentNodes.h"
|
||||
#include "Serialization.h"
|
||||
#include "RNNNodes.h"
|
||||
|
||||
using namespace Microsoft::MSR::CNTK;
|
||||
|
||||
namespace CNTK
|
||||
{
|
||||
// Names for the reduction operations as used by the CNTK ReduceElementsNode
|
||||
/*static*/ const std::wstring PrimitiveFunction::InternalSumReductionOpName = L"Sum";
|
||||
/*static*/ const std::wstring PrimitiveFunction::InternalLogSumReductionOpName = L"LogSum";
|
||||
/*static*/ const std::wstring PrimitiveFunction::InternalMeanReductionOpName = L"Mean";
|
||||
/*static*/ const std::wstring PrimitiveFunction::InternalMaxReductionOpName = L"Max";
|
||||
/*static*/ const std::wstring PrimitiveFunction::InternalMinReductionOpName = L"Min";
|
||||
/*static*/ const std::wstring PrimitiveFunction::InternalAllReductionOpName = L"All";
|
||||
/*static*/ const std::wstring PrimitiveFunction::InternalAnyReductionOpName = L"Any";
|
||||
|
||||
// Names of the various attributes of CNTK primitive Functions
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameAxis = L"axis";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameAxis1 = L"axis1";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameAxis2 = L"axis2";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameAllowDuplicates = L"allowDuplicates";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameNumSamples = L"numSamples";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameDropoutRate = L"dropoutRate";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameNewShape = L"newShape";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameOutputRank = L"outputRank";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameInferInputRankToMap = L"inferInputRankToMap";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameOffset = L"offset";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameStrides = L"strides";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameSharing = L"sharing";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameAutoPadding = L"autoPadding";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameLowerPad = L"lowerPad";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameUpperPad = L"upperPad";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameTranspose = L"transpose";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameMaxTempMemSizeInSamples = L"maxTempMemSizeInSamples";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameROIOutputShape = L"roiOutputShape";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNamePoolingType = L"poolingType";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNamePoolingWindowShape = L"poolingWindowShape";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameSpatial = L"spatial";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameNormalizationTimeConstant = L"normalizationTimeConstant";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameBlendTimeConstant = L"blendTimeConstant";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameEpsilon = L"epsilon";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameUseCuDNNEngine = L"useCuDNNEngine";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameNewDynamicAxes = L"newDynamicAxes";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameNewSequenceAxisLengthScalingFactor = L"newSequenceAxisLengthScalingFactor";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameNewSequenceAxisLengthAdditiveFactor = L"newSequenceAxisLengthAdditiveFactor";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameBeginIndex = L"beginIndex";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameEndIndex = L"endIndex";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameReductionOpName = L"reductionOpName";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameBidirectional = L"bidirectional";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameNumLayers = L"numLayers";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameHiddenSize = L"hiddenSize";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameRecurrentOp = L"recurrentOp";
|
||||
|
||||
/*static*/ std::vector<Variable> PrimitiveFunction::GetOutputVariables(PrimitiveOpType op,
|
||||
std::vector<Variable>& inputs,
|
||||
Function* owner,
|
||||
Dictionary& functionConfig,
|
||||
bool inferDimensions,
|
||||
const std::wstring& functionName)
|
||||
{
|
||||
if (op == PrimitiveOpType::Combine)
|
||||
return inputs;
|
||||
|
||||
// We use the first non-constant input operand's DataType as the output DataType
|
||||
// In case there are no non-constant known DataTypes, we just pick the first known operand DataType
|
||||
// Also, all the known DataTypes of operands should match except for constants where coercion is allowed
|
||||
DataType firstKnownInputDataType = DataType::Unknown;
|
||||
DataType outputDataType = DataType::Unknown;
|
||||
size_t i = 0;
|
||||
while (i < inputs.size())
|
||||
{
|
||||
auto input = inputs[i++];
|
||||
auto inputDataType = input.GetDataType();
|
||||
if (inputDataType != DataType::Unknown)
|
||||
{
|
||||
if (firstKnownInputDataType == DataType::Unknown)
|
||||
firstKnownInputDataType = inputDataType;
|
||||
|
||||
if (outputDataType == DataType::Unknown)
|
||||
{
|
||||
if (!input.IsConstant())
|
||||
outputDataType = inputDataType;
|
||||
}
|
||||
else
|
||||
{
|
||||
// The DataType of all operands should match except for Constants where we allow coercion
|
||||
if ((inputDataType != DataType::Unknown) && (inputDataType != outputDataType) && !input.IsConstant())
|
||||
InvalidArgument("Primitive function with op type %S has operands with different DataTypes %s and %s", PrimitiveOpTypeName(op).c_str(), DataTypeName(outputDataType), DataTypeName(inputDataType));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (outputDataType == DataType::Unknown)
|
||||
outputDataType = firstKnownInputDataType;
|
||||
|
||||
// We currently require that the inputs' dynamic axes, if any, match
|
||||
std::vector<Axis> outputDynamicAxes;
|
||||
if ((op == PrimitiveOpType::SumAll) ||
|
||||
(op == PrimitiveOpType::SquaredError) ||
|
||||
(op == PrimitiveOpType::CrossEntropyWithSoftmax) ||
|
||||
(op == PrimitiveOpType::ClassificationError) ||
|
||||
(op == PrimitiveOpType::Logistic))
|
||||
{
|
||||
outputDynamicAxes = std::vector<Axis>({});
|
||||
}
|
||||
else if (op == PrimitiveOpType::Where)
|
||||
{
|
||||
if (functionConfig.Contains(PrimitiveFunction::AttributeNameNewDynamicAxes))
|
||||
outputDynamicAxes = AsVector<Axis>(functionConfig[PrimitiveFunction::AttributeNameNewDynamicAxes].Value<std::vector<DictionaryValue>>());
|
||||
else
|
||||
{
|
||||
if (inputs[0].DynamicAxes() == Axis::UnknownDynamicAxes())
|
||||
outputDynamicAxes = Axis::UnknownDynamicAxes();
|
||||
else
|
||||
{
|
||||
if (functionConfig.Contains(PrimitiveFunction::AttributeNameNewSequenceAxisLengthScalingFactor) &&
|
||||
functionConfig.Contains(PrimitiveFunction::AttributeNameNewSequenceAxisLengthAdditiveFactor))
|
||||
{
|
||||
size_t newSequenceAxisLengthScalingFactor = functionConfig[PrimitiveFunction::AttributeNameNewSequenceAxisLengthScalingFactor].Value<size_t>();
|
||||
int newSequenceAxisLengthAdditiveFactor = functionConfig[PrimitiveFunction::AttributeNameNewSequenceAxisLengthAdditiveFactor].Value<int>();
|
||||
|
||||
auto derivedDynamicAxes = GetDerivedDynamicAxes(inputs[0].DynamicAxes()[0], newSequenceAxisLengthScalingFactor, newSequenceAxisLengthAdditiveFactor);
|
||||
std::copy(derivedDynamicAxes.begin(), derivedDynamicAxes.end(), std::back_inserter(outputDynamicAxes));
|
||||
}
|
||||
else
|
||||
{
|
||||
outputDynamicAxes.push_back(Axis::NewUniqueDynamicAxis(L"whereNodeDynamicAxis"));
|
||||
}
|
||||
|
||||
for (size_t i = 1; i < inputs[0].DynamicAxes().size(); ++i)
|
||||
outputDynamicAxes.push_back(inputs[0].DynamicAxes()[i]);
|
||||
|
||||
functionConfig[PrimitiveFunction::AttributeNameNewDynamicAxes] = AsDictionaryValueVector(outputDynamicAxes);
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (op == PrimitiveOpType::ScatterPacked)
|
||||
outputDynamicAxes = inputs[2].DynamicAxes();
|
||||
else if ((op == PrimitiveOpType::PackedIndex) || (op == PrimitiveOpType::GatherPacked))
|
||||
outputDynamicAxes = inputs[1].DynamicAxes();
|
||||
else if (op == PrimitiveOpType::ReconcileDynamicAxis)
|
||||
outputDynamicAxes = inputs[1].DynamicAxes();
|
||||
else
|
||||
{
|
||||
auto allInputDynamicAxesEmpty = std::find_if(inputs.begin(), inputs.end(), [](const Variable& input) { return !input.DynamicAxes().empty(); }) == inputs.end();
|
||||
if (!allInputDynamicAxesEmpty)
|
||||
{
|
||||
outputDynamicAxes = Axis::UnknownDynamicAxes();
|
||||
for (auto inputVar : inputs)
|
||||
{
|
||||
auto currentInputDynamicAxes = inputVar.DynamicAxes();
|
||||
if (!currentInputDynamicAxes.empty() && (currentInputDynamicAxes != Axis::UnknownDynamicAxes()))
|
||||
{
|
||||
if (outputDynamicAxes == Axis::UnknownDynamicAxes())
|
||||
outputDynamicAxes = currentInputDynamicAxes;
|
||||
else
|
||||
{
|
||||
if (currentInputDynamicAxes != outputDynamicAxes)
|
||||
LogicError("Currently if an operand of a elementwise operation has any dynamic axes, those must match the dynamic axes of the other operands");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
NDShape outputShape;
|
||||
bool areAnyInputShapesUnknown = (std::find_if(inputs.begin(), inputs.end(), [](const Variable& input) { return input.Shape().IsUnknown(); }) != inputs.end());
|
||||
if (areAnyInputShapesUnknown)
|
||||
outputShape = NDShape::Unknown;
|
||||
else
|
||||
{
|
||||
switch (op)
|
||||
{
|
||||
case PrimitiveOpType::Negate:
|
||||
case PrimitiveOpType::Sigmoid:
|
||||
case PrimitiveOpType::Tanh:
|
||||
case PrimitiveOpType::ReLU:
|
||||
case PrimitiveOpType::Exp:
|
||||
case PrimitiveOpType::Log:
|
||||
case PrimitiveOpType::Sqrt:
|
||||
case PrimitiveOpType::Floor:
|
||||
case PrimitiveOpType::Abs:
|
||||
case PrimitiveOpType::Reciprocal:
|
||||
case PrimitiveOpType::Softmax:
|
||||
case PrimitiveOpType::Hardmax:
|
||||
case PrimitiveOpType::Dropout:
|
||||
case PrimitiveOpType::Where:
|
||||
case PrimitiveOpType::LogSoftmax:
|
||||
{
|
||||
assert(inputs.size() == 1);
|
||||
outputShape = UnaryElementwiseOpOutputShape(inputs[0].Shape());
|
||||
break;
|
||||
}
|
||||
case PrimitiveOpType::TransposeAxes:
|
||||
{
|
||||
assert(inputs.size() == 1);
|
||||
|
||||
auto axis1 = NormalizeStaticAxis(functionConfig[PrimitiveFunction::AttributeNameAxis1].Value<Axis>(), inputs[0].Shape());
|
||||
auto axis2 = NormalizeStaticAxis(functionConfig[PrimitiveFunction::AttributeNameAxis2].Value<Axis>(), inputs[0].Shape());
|
||||
|
||||
if (!axis1.IsStaticAxis() || !axis2.IsStaticAxis())
|
||||
LogicError("TransposeAxes operation currently does not support transposing dynamic axes");
|
||||
|
||||
VerifyStaticAxis(axis1, inputs[0].Shape());
|
||||
VerifyStaticAxis(axis2, inputs[0].Shape());
|
||||
|
||||
outputShape = inputs[0].Shape();
|
||||
std::swap(outputShape[axis1.StaticAxisIndex()], outputShape[axis2.StaticAxisIndex()]);
|
||||
break;
|
||||
}
|
||||
case PrimitiveOpType::Slice:
|
||||
{
|
||||
assert(inputs.size() == 1);
|
||||
auto axis = NormalizeStaticAxis(functionConfig[PrimitiveFunction::AttributeNameAxis].Value<Axis>(), inputs[0].Shape());
|
||||
|
||||
auto beginIndex = functionConfig[PrimitiveFunction::AttributeNameBeginIndex].Value<int>();
|
||||
auto endIndex = functionConfig[PrimitiveFunction::AttributeNameEndIndex].Value<int>();
|
||||
if (!axis.IsStaticAxis())
|
||||
LogicError("Built-in Slice operation currently does not support slicing along dynamic axis");
|
||||
|
||||
VerifyStaticAxis(axis, inputs[0].Shape());
|
||||
|
||||
size_t sliceAxisDim = inputs[0].Shape()[axis.StaticAxisIndex()];
|
||||
int realBeginIndex = (beginIndex >= 0) ? beginIndex : beginIndex + sliceAxisDim;
|
||||
int realEndIndex = (endIndex > 0) ? endIndex : endIndex + sliceAxisDim;
|
||||
if ((sliceAxisDim < realEndIndex) || (realEndIndex < realBeginIndex) || (realBeginIndex < 0))
|
||||
RuntimeError("Slice operation: Index range [%d,%d), interpreted as [%d,%d), is invalid for input's shape ([%S]).",
|
||||
beginIndex,
|
||||
endIndex,
|
||||
realBeginIndex,
|
||||
realEndIndex,
|
||||
AsStringForErrorReporting(inputs[0].Shape()).c_str());
|
||||
|
||||
auto outputTensorShape = AsTensorShape(inputs[0].Shape());
|
||||
|
||||
// propagate as much as we can
|
||||
if ((axis.StaticAxisIndex() < (int)outputTensorShape.GetRank()) && (0 <= realBeginIndex) && (realBeginIndex <= realEndIndex) && (realEndIndex <= sliceAxisDim))
|
||||
outputTensorShape.NarrowTo(axis.StaticAxisIndex(), realBeginIndex, realEndIndex);
|
||||
|
||||
outputShape = AsNDShape(outputTensorShape, /*allowNonFlattenableTensorShapes = */ true);
|
||||
break;
|
||||
}
|
||||
case PrimitiveOpType::Reshape:
|
||||
{
|
||||
auto newShape = functionConfig[PrimitiveFunction::AttributeNameNewShape].Value<NDShape>();
|
||||
outputShape = ReshapeOutputShape(inputs[0].Shape(), newShape);
|
||||
break;
|
||||
}
|
||||
case PrimitiveOpType::ROIPooling:
|
||||
{
|
||||
assert(inputs.size() == 2);
|
||||
auto convMapShape = inputs[0].Shape();
|
||||
auto roisShape = inputs[1].Shape();
|
||||
auto roiOutputShape = functionConfig[PrimitiveFunction::AttributeNameROIOutputShape].Value<NDShape>();
|
||||
|
||||
auto outW = roiOutputShape[0];
|
||||
auto outH = roiOutputShape[1];
|
||||
auto numChannels = convMapShape[2];
|
||||
auto roisPerImage = roisShape[1];
|
||||
|
||||
if (roiOutputShape.Rank() != 2)
|
||||
InvalidArgument("ROIPoolingNode: roi output shape must have two dimensions ([W x H]).");
|
||||
|
||||
if (convMapShape[0] < outW || convMapShape[1] < outH)
|
||||
InvalidArgument("ROIPoolingNode: inputWidth must >= windowWidth and inputHeight must >= windowHeight.");
|
||||
|
||||
if (convMapShape[2] < 1)
|
||||
InvalidArgument("ROIPoolingNode: input must have at least one channel ([W x H x C]).");
|
||||
|
||||
if (roisShape[0] != 4)
|
||||
InvalidArgument("ROIPoolingNode: ROI input must have the following shape: [4 x roisPerImage].");
|
||||
|
||||
if (roisPerImage < 1)
|
||||
InvalidArgument("ROIPoolingNode: ROI input must contain at least one ROI ([4 x roisPerImage]).");
|
||||
|
||||
outputShape = { outW, outH, numChannels, roisPerImage };
|
||||
break;
|
||||
}
|
||||
case PrimitiveOpType::Pooling:
|
||||
{
|
||||
assert(inputs.size() == 1);
|
||||
auto poolingWindowsShape = functionConfig[PrimitiveFunction::AttributeNamePoolingWindowShape].Value<NDShape>();
|
||||
auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value<NDShape>();
|
||||
auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value<NDShape>();
|
||||
auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value<NDShape>();
|
||||
auto autoPadding = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
|
||||
NDShape outputMapCount = { 1 };
|
||||
std::vector<bool> sharing = { true };
|
||||
auto inputShape = inputs[0].Shape();
|
||||
|
||||
// In case of pooling if the kernel shape is unknown, then treat it as global pooling.
|
||||
if (poolingWindowsShape == NDShape::Unknown)
|
||||
{
|
||||
if ((std::find(autoPadding.begin(), autoPadding.end(), true) != autoPadding.end()) ||
|
||||
(lowerPad.TotalSize() > 0) || (upperPad.TotalSize() > 0))
|
||||
RuntimeError("Padding isn't allowed for Unknown shape!");
|
||||
|
||||
poolingWindowsShape = inputShape.SubShape(0, inputShape.Rank()-1);
|
||||
functionConfig[PrimitiveFunction::AttributeNamePoolingWindowShape] = poolingWindowsShape;
|
||||
}
|
||||
|
||||
outputShape = ConvolutionOpOutputShape(op, inputShape, poolingWindowsShape, outputMapCount, strides, sharing, autoPadding, lowerPad, upperPad, false, inferDimensions);
|
||||
break;
|
||||
}
|
||||
case PrimitiveOpType::SumAll:
|
||||
assert(inputs.size() == 1);
|
||||
outputShape = {1};
|
||||
break;
|
||||
case PrimitiveOpType::Plus:
|
||||
case PrimitiveOpType::LogPlus:
|
||||
case PrimitiveOpType::Minus:
|
||||
case PrimitiveOpType::ElementTimes:
|
||||
case PrimitiveOpType::Equal:
|
||||
case PrimitiveOpType::NotEqual:
|
||||
case PrimitiveOpType::Less:
|
||||
case PrimitiveOpType::LessEqual:
|
||||
case PrimitiveOpType::Greater:
|
||||
case PrimitiveOpType::GreaterEqual:
|
||||
case PrimitiveOpType::PastValue:
|
||||
case PrimitiveOpType::FutureValue:
|
||||
{
|
||||
assert(inputs.size() == 2);
|
||||
if ((op == PrimitiveOpType::PastValue) || (op == PrimitiveOpType::FutureValue))
|
||||
{
|
||||
Variable inputOperandVar = inputs[0];
|
||||
Variable initialStateVar = inputs[1];
|
||||
|
||||
// TODO: We currently only support input operand with 1 dynamic axis for PastValue/FutureValue
|
||||
if ((inputOperandVar.DynamicAxes() != Axis::UnknownDynamicAxes()) && (inputOperandVar.DynamicAxes().size() != 2))
|
||||
LogicError("Currently PastValue/FutureValue Function only supports input operand with 2 dynamic axis (1 sequence-axis and 1 batch-axis)");
|
||||
|
||||
if (!initialStateVar.DynamicAxes().empty())
|
||||
LogicError("Currently PastValue/FutureValue Function does not support initial state operand with dynamic axes!");
|
||||
}
|
||||
|
||||
outputShape = BinaryElementwiseOpOutputShape(op, inputs[0], inputs[1], true, inferDimensions);
|
||||
break;
|
||||
}
|
||||
case PrimitiveOpType::Times:
|
||||
{
|
||||
assert(inputs.size() == 2);
|
||||
auto outputRank = functionConfig[PrimitiveFunction::AttributeNameOutputRank].Value<size_t>();
|
||||
auto inferInputRankToMap = functionConfig[PrimitiveFunction::AttributeNameInferInputRankToMap].Value<int>();
|
||||
outputShape = TimesOpOutputShape(inputs[0], inputs[1], outputRank, inferInputRankToMap, inferDimensions);
|
||||
break;
|
||||
}
|
||||
case PrimitiveOpType::TransposeTimes:
|
||||
{
|
||||
assert(inputs.size() == 2);
|
||||
|
||||
auto transposeShapeFunc = [](const NDShape& shape) {
|
||||
NDShape transposedShape(std::max<size_t>(2, shape.Rank()), 1);
|
||||
for (size_t i = 0; i < shape.Rank(); ++i)
|
||||
transposedShape[transposedShape.Rank() - i - 1] = shape[i];
|
||||
|
||||
return transposedShape;
|
||||
};
|
||||
|
||||
if (inputs[0].Shape().Rank() > 2)
|
||||
LogicError("TransposeTimes operation currently only supports %s operands of rank 1 or 2", Internal::IsReversingTensorShapesInErrorMessagesEnabled() ? "right" : "left");
|
||||
|
||||
NDShape transposedLeftOperandShape = transposeShapeFunc(inputs[0].Shape());
|
||||
Variable dummyLeftOperand = PlaceholderVariable(transposedLeftOperandShape);
|
||||
size_t outputRank = functionConfig[PrimitiveFunction::AttributeNameOutputRank].Value<size_t>();
|
||||
outputShape = TimesOpOutputShape(dummyLeftOperand, inputs[1], outputRank, -1, inferDimensions);
|
||||
if (dummyLeftOperand.Shape() != transposedLeftOperandShape)
|
||||
inputs[0].m_dataFields->m_shape = transposeShapeFunc(dummyLeftOperand.Shape());
|
||||
|
||||
break;
|
||||
}
|
||||
case PrimitiveOpType::Convolution:
|
||||
{
|
||||
assert(inputs.size() == 2);
|
||||
auto& strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value<NDShape>();
|
||||
auto& lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value<NDShape>();
|
||||
auto& upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value<NDShape>();
|
||||
auto sharing = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameSharing].Value<std::vector<DictionaryValue>>());
|
||||
auto autoPadding = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
|
||||
bool transpose = functionConfig[PrimitiveFunction::AttributeNameTranspose].Value<bool>();
|
||||
if (inputs[0].Shape().Rank() < inputs[1].Shape().Rank())
|
||||
InvalidArgument("The convolution map should have at least as many axes as the shape of the input it operates on!");
|
||||
|
||||
NDShape outputMapCount, kernelShape;
|
||||
std::tie(outputMapCount, kernelShape) = GetConvolutionOutputMapCountAndKernelShape(inputs[0].Shape(), inputs[1].Shape());
|
||||
auto originalKernelShape = kernelShape;
|
||||
outputShape = ConvolutionOpOutputShape(op, inputs[1].Shape(), kernelShape, outputMapCount, strides, sharing, autoPadding, lowerPad, upperPad, transpose, inferDimensions);
|
||||
if (originalKernelShape != kernelShape)
|
||||
{
|
||||
for (size_t i = 0; i < kernelShape.Rank(); ++i)
|
||||
inputs[0].m_dataFields->m_shape[i] = kernelShape[i];
|
||||
}
|
||||
|
||||
functionConfig[PrimitiveFunction::AttributeNameSharing] = AsDictionaryValueVector(sharing);
|
||||
functionConfig[PrimitiveFunction::AttributeNameAutoPadding] = AsDictionaryValueVector(autoPadding);
|
||||
break;
|
||||
}
|
||||
case PrimitiveOpType::Logistic:
|
||||
case PrimitiveOpType::SquaredError:
|
||||
case PrimitiveOpType::CrossEntropyWithSoftmax:
|
||||
case PrimitiveOpType::ClassificationError:
|
||||
{
|
||||
if ((op == PrimitiveOpType::ClassificationError) || (op == PrimitiveOpType::Logistic))
|
||||
assert(inputs.size() >= 2);
|
||||
else
|
||||
assert(inputs.size() == 2);
|
||||
|
||||
if ((inputs[0].Shape().Rank() > 2) || ((inputs[0].Shape().Rank() > 1) && (inputs[0].Shape()[1] != 1)))
|
||||
InvalidArgument("The shape of input operands for the %S operation should have at most one axis", PrimitiveOpTypeName(op).c_str());
|
||||
|
||||
auto predictionShape = inputs[0].Shape();
|
||||
auto labelsShape = inputs[1].Shape();
|
||||
if (predictionShape != labelsShape)
|
||||
RuntimeError("Prediction output operand's shape %S is incompatible with label operand's shape %S for the %S operation", AsStringForErrorReporting(predictionShape).c_str(), AsStringForErrorReporting(labelsShape).c_str(), PrimitiveOpTypeName(op).c_str());
|
||||
|
||||
std::vector<int> reductionAxes;
|
||||
for (int i = 0; i < (int)inputs[0].Shape().Rank(); ++i)
|
||||
reductionAxes.push_back(i);
|
||||
|
||||
outputShape = ReductionOpOutputShape(op, predictionShape, reductionAxes, /*preserveReductionAxes =*/ false);
|
||||
break;
|
||||
}
|
||||
case PrimitiveOpType::ReduceElements:
|
||||
{
|
||||
assert(inputs.size() == 1);
|
||||
auto reductionAxis = NormalizeStaticAxis(functionConfig[PrimitiveFunction::AttributeNameAxis].Value<Axis>(), inputs[0].Shape());
|
||||
if (reductionAxis == Axis::AllStaticAxes())
|
||||
outputShape = {};
|
||||
else
|
||||
{
|
||||
std::vector<int> reductionAxes = { reductionAxis.StaticAxisIndex() };
|
||||
outputShape = ReductionOpOutputShape(op, inputs[0].Shape(), reductionAxes, /*preserveReductionAxes =*/ true);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case PrimitiveOpType::BatchNormalization:
|
||||
{
|
||||
assert(inputs.size() == 5);
|
||||
auto spatial = functionConfig[PrimitiveFunction::AttributeNameSpatial].Value<bool>();
|
||||
outputShape = BatchNormalizationOutputShape(inputs, spatial, inferDimensions);
|
||||
break;
|
||||
}
|
||||
case PrimitiveOpType::PackedIndex:
|
||||
outputShape = UnaryElementwiseOpOutputShape(inputs[1].Shape());
|
||||
break;
|
||||
case PrimitiveOpType::GatherPacked:
|
||||
{
|
||||
bool sourceHasDynamicAxis = !inputs[0].DynamicAxes().empty();
|
||||
|
||||
// inherit tensor dimension from sourceData, minus the last (column or time) dimension. TODO this needs to become simpler...
|
||||
if (sourceHasDynamicAxis)
|
||||
outputShape = inputs[0].Shape();
|
||||
else
|
||||
{
|
||||
if (inputs[0].Shape().Rank() > 1)
|
||||
outputShape = outputShape.SubShape(0, outputShape.Rank() - 1);
|
||||
else
|
||||
outputShape = {};
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
case PrimitiveOpType::ScatterPacked:
|
||||
{
|
||||
if (inputs[0].DynamicAxes().empty() || inputs[1].DynamicAxes().empty() || inputs[2].DynamicAxes().empty())
|
||||
InvalidArgument("ScatterPacked requires all its operands to have dynamic axes");
|
||||
|
||||
outputShape = inputs[0].Shape();
|
||||
break;
|
||||
}
|
||||
case PrimitiveOpType::Clip:
|
||||
assert(inputs.size() == 3);
|
||||
outputShape = UnaryElementwiseOpOutputShape(inputs[0].Shape());
|
||||
break;
|
||||
case PrimitiveOpType::Select:
|
||||
assert(inputs.size() == 3);
|
||||
outputShape = NaryElementwiseOpOutputShape(op, inputs, true);
|
||||
break;
|
||||
case PrimitiveOpType::Splice:
|
||||
{
|
||||
assert(inputs.size() >= 2);
|
||||
auto maxInputRank = MaxInputRank(inputs);
|
||||
auto spliceAxis = NormalizeStaticAxis(functionConfig[PrimitiveFunction::AttributeNameAxis].Value<Axis>(), NDShape(maxInputRank));
|
||||
|
||||
if (!spliceAxis.IsStaticAxis())
|
||||
LogicError("Splice operation currently does not support splicing along dynamic axis");
|
||||
|
||||
if (spliceAxis.StaticAxisIndex() < 0)
|
||||
InvalidArgument("Splice: The axis argument's static axis index must be >= 0!");
|
||||
|
||||
outputShape = SpliceOutputShape(inputs, spliceAxis.StaticAxisIndex());
|
||||
break;
|
||||
}
|
||||
case PrimitiveOpType::RandomSample:
|
||||
case PrimitiveOpType::RandomSampleInclusionFrequency:
|
||||
{
|
||||
auto numSamples = functionConfig[PrimitiveFunction::AttributeNameNumSamples].Value<size_t>();
|
||||
auto allowDuplicates = functionConfig[PrimitiveFunction::AttributeNameAllowDuplicates].Value<bool>();
|
||||
|
||||
if (numSamples == 0)
|
||||
InvalidArgument("Number of requested samples is zero.");
|
||||
|
||||
let& shape = inputs[0].Shape();
|
||||
size_t numClasses = shape.Dimensions()[0];
|
||||
|
||||
if (numClasses != NDShape::InferredDimension && !allowDuplicates && numClasses <= numSamples)
|
||||
InvalidArgument("For sampling without duplicates the number of requested samples (%lu) needs to be less than the number of classes (%lu).", numSamples, numClasses);
|
||||
|
||||
// within this block we handle RandomSample and RandomSampleInclusionFrequency
|
||||
if (op == PrimitiveOpType::RandomSampleInclusionFrequency)
|
||||
outputShape = shape;
|
||||
else
|
||||
{
|
||||
vector<size_t> dimensions{ numSamples, numClasses };
|
||||
outputShape = NDShape(dimensions);
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
case PrimitiveOpType::OptimizedRNNStack:
|
||||
{
|
||||
assert(inputs.size() == 2);
|
||||
auto operand = inputs[0];
|
||||
auto parameter = inputs[1];
|
||||
if (operand.Shape().Rank() != 1)
|
||||
InvalidArgument("OptimizedRNNStack: input must have rank 1; actual input rank is %lu", operand.Shape().Rank());
|
||||
if (operand.DynamicAxes().empty())
|
||||
InvalidArgument("OptimizedRNNStack: input must have at least one dynamic axis");
|
||||
auto numLayers = functionConfig[PrimitiveFunction::AttributeNameNumLayers].Value<size_t>();
|
||||
if (numLayers == 0)
|
||||
InvalidArgument("Number of layers in OptimizedRNNStack operation should be positive");
|
||||
auto bidirectional = functionConfig[PrimitiveFunction::AttributeNameBidirectional].Value<bool>();
|
||||
auto hiddenSize = functionConfig[PrimitiveFunction::AttributeNameHiddenSize].Value<size_t>();
|
||||
|
||||
// output dims
|
||||
outputShape = operand.Shape();
|
||||
outputShape[0] = (bidirectional ? 2 : 1) * hiddenSize;
|
||||
// infer input size
|
||||
// Note: Output dim is second axis, so say initOutputRank=-1.
|
||||
if (parameter.Shape().Rank() == 2)
|
||||
{
|
||||
const auto recurrentOp = functionConfig[PrimitiveFunction::AttributeNameRecurrentOp].Value<std::wstring>();
|
||||
const auto attributes = RnnAttributes(bidirectional, numLayers, hiddenSize, recurrentOp, -1);
|
||||
const auto numParameters = attributes.GetNumParameters(operand.Shape().TotalSize());
|
||||
std::vector<std::pair<Variable, NDShape>> newOperandShapes = { { parameter, std::move(NDShape({ numParameters.first, numParameters.second })) } };
|
||||
UpdateOperandShapes(newOperandShapes);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case PrimitiveOpType::ReconcileDynamicAxis:
|
||||
{
|
||||
assert(inputs.size() == 2);
|
||||
auto operand = inputs[0];
|
||||
auto layout = inputs[1];
|
||||
if (operand.DynamicAxes().empty())
|
||||
InvalidArgument("ReconcileDynamicAxis: input must have at least one dynamic axis");
|
||||
if (layout.DynamicAxes().empty())
|
||||
InvalidArgument("ReconcileDynamicAxis: layout must have at least one dynamic axis");
|
||||
outputShape = operand.Shape();
|
||||
break;
|
||||
}
|
||||
default:
|
||||
LogicError("Specified op %S not yet supported", PrimitiveOpTypeName(op).c_str());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return{ OutputVariable(outputShape, outputDataType, owner, outputDynamicAxes, functionName.empty() ? L"" : functionName + L"_output") };
|
||||
}
|
||||
|
||||
static const std::wstring s_primitiveFunctionTypeValue = L"PrimitiveFunction";
|
||||
|
||||
/*virtual*/ Dictionary PrimitiveFunction::Serialize() const
|
||||
{
|
||||
Dictionary dict;
|
||||
|
||||
dict[versionKey] = CurrentVersion();
|
||||
dict[typeKey] = s_primitiveFunctionTypeValue;
|
||||
dict[opKey] = static_cast<size_t>(m_op);
|
||||
dict[attributesKey] = Attributes();
|
||||
dict[uidKey] = Uid();
|
||||
dict[nameKey] = Name();
|
||||
|
||||
auto inputs = Inputs();
|
||||
vector<DictionaryValue> inputUids;
|
||||
inputUids.reserve(inputs.size());
|
||||
for (auto& input : inputs)
|
||||
{
|
||||
inputUids.push_back(input.Uid());
|
||||
}
|
||||
|
||||
dict[inputsKey] = std::move(inputUids);
|
||||
|
||||
return dict;
|
||||
}
|
||||
|
||||
/*static*/ FunctionPtr PrimitiveFunction::Deserialize(const Dictionary& dict,
|
||||
const std::unordered_map<std::wstring, Variable>& uidToVariableMap,
|
||||
const CNTK::DeviceDescriptor& device)
|
||||
{
|
||||
static const vector<std::wstring> s_requiredDictionaryKeys = { typeKey, opKey, uidKey, attributesKey, inputsKey, nameKey };
|
||||
|
||||
size_t version = ValidateDictionary<PrimitiveFunction>(dict, s_requiredDictionaryKeys, s_primitiveFunctionTypeValue, s_serializationVersion);
|
||||
|
||||
PrimitiveOpType op = PrimitiveOpType(dict[opKey].Value<std::size_t>());
|
||||
|
||||
// The hard requirement that the serialization depends on is that
|
||||
// new op type values are only added to the end of the list, after Combine.
|
||||
// This also applies to other enums (DataType, VariableKind, etc.)
|
||||
if (op > PrimitiveOpType::Combine)
|
||||
{
|
||||
LogicError("Unexpected variable '%ls':'%u' (%s).",
|
||||
opKey.c_str(),
|
||||
static_cast<std::underlying_type<CNTK::PrimitiveOpType>::type>(op),
|
||||
GetVersionsString<PrimitiveFunction>(s_serializationVersion, version).c_str());
|
||||
}
|
||||
|
||||
const auto& uid = dict[uidKey].Value<std::wstring>();
|
||||
const auto& name = dict[nameKey].Value<std::wstring>();
|
||||
auto attributes = dict[attributesKey].Value<Dictionary>();
|
||||
const auto& inputUids = dict[inputsKey].Value<vector<DictionaryValue>>();
|
||||
|
||||
std::vector<Variable> inputs;
|
||||
inputs.reserve(inputUids.size());
|
||||
|
||||
for (const auto& dictionaryValue : inputUids)
|
||||
{
|
||||
const auto& inputUid = dictionaryValue.Value<std::wstring>();
|
||||
if (uidToVariableMap.find(inputUid) == uidToVariableMap.end())
|
||||
{
|
||||
LogicError("There are no inputs corresponging to input uid = '%ls' "
|
||||
"(%s).", inputUid.c_str(), GetVersionsString<PrimitiveFunction>(s_serializationVersion, version).c_str());
|
||||
}
|
||||
inputs.push_back(uidToVariableMap.at(inputUid));
|
||||
}
|
||||
|
||||
return std::shared_ptr<PrimitiveFunction>(new PrimitiveFunction(op, inputs, std::move(attributes), name, uid),
|
||||
[](PrimitiveFunction* ptr) { delete ptr; });
|
||||
}
|
||||
}
|
@@ -8,12 +8,9 @@
#include "stdafx.h"
#include "CNTKLibrary.h"
#include "PrimitiveOpType.h"
#include <iterator>
#include "ComputationNetwork.h"
#include "Utils.h"
#include "ConvolveGeometry.h"
#include "ConvolutionalNodes.h"
#include "BackCompat.h"

namespace std
{
@ -648,256 +645,4 @@ namespace CNTK
|
|||
// a more meaningful message when trying to load a new model with a stale binary.
|
||||
static const size_t s_serializationVersion = 2;
|
||||
};
|
||||
|
||||
class CNTKBackPropState final : public BackPropState
|
||||
{
|
||||
public:
|
||||
CNTKBackPropState(const FunctionPtr& function, const DeviceDescriptor& computeDevice, const std::pair<Variable, int64_t>& evalTimeStamp)
|
||||
: BackPropState(function, computeDevice), m_evalTimeStamp(evalTimeStamp)
|
||||
{}
|
||||
|
||||
std::pair<Variable, int64_t> EvalTimeStamp() const
|
||||
{
|
||||
return m_evalTimeStamp;
|
||||
}
|
||||
|
||||
private:
|
||||
std::pair<Variable, int64_t> m_evalTimeStamp;
|
||||
};
|
||||
typedef std::shared_ptr<CNTKBackPropState> CNTKBackPropStatePtr;
|
||||
|
||||
class CompositeFunction;
|
||||
typedef std::shared_ptr<CompositeFunction> CompositeFunctionPtr;
|
||||
|
||||
class CompositeFunction final : public Function
|
||||
{
|
||||
friend class Function;
|
||||
friend class Trainer;
|
||||
friend class CompositeMinibatchSource;
|
||||
friend class PackedValue;
|
||||
|
||||
template <typename T, typename ...CtorArgTypes>
|
||||
friend inline std::shared_ptr<T> MakeSharedObject(CtorArgTypes&& ...ctorArgs);
|
||||
|
||||
friend void Internal::SaveAsLegacyModel(const FunctionPtr& rootFunction, const std::wstring& modelFile);
|
||||
|
||||
friend void ComputeInputPerDimMeansAndInvStdDevs(const MinibatchSourcePtr& minibatchSource,
|
||||
std::unordered_map<StreamInformation, std::pair<NDArrayViewPtr, NDArrayViewPtr>>& computedMeanAndInvStdDevs,
|
||||
const DeviceDescriptor& device /*= DeviceDescriptor::CPUDevice()*/);
|
||||
|
||||
static std::atomic<unsigned int> s_nextAutoGeneratedDynamicAxis;
|
||||
|
||||
static const std::wstring CompositeFunctionOpName;
|
||||
|
||||
public:
|
||||
static const std::wstring InternalDefaultDynamicAxisName;
|
||||
static const std::wstring InternalNoSequenceAxisName;
|
||||
|
||||
static Axis NextAutoGeneratedDynamicAxis()
|
||||
{
|
||||
static const std::wstring s_autoGeneratedDynamicAxisNamePrefix = L"autoGeneratedDynamicAxis_";
|
||||
return Axis(s_autoGeneratedDynamicAxisNamePrefix + std::to_wstring(s_nextAutoGeneratedDynamicAxis++));
|
||||
}
|
||||
|
||||
public:
|
||||
static CompositeFunctionPtr Create(const FunctionPtr& rootFunction, const std::wstring& name = L"", const std::wstring& uid = L"")
|
||||
{
|
||||
std::unordered_set<FunctionPtr> visitedFunctions;
|
||||
|
||||
// Call Collect to get the set of all functions in the graph
|
||||
Collect(rootFunction, visitedFunctions);
|
||||
|
||||
return MakeSharedObject<CompositeFunction>(rootFunction, std::move(visitedFunctions), name, uid);
|
||||
}
|
||||
|
||||
virtual BackPropStatePtr Forward(const std::unordered_map<Variable, ValuePtr>& arguments,
|
||||
std::unordered_map<Variable, ValuePtr>& outputs,
|
||||
const DeviceDescriptor& computeDevice,
|
||||
const std::unordered_set<Variable>& outputsToRetainBackwardStateFor) override;
|
||||
|
||||
virtual void Backward(const BackPropStatePtr& state,
|
||||
const std::unordered_map<Variable, ValuePtr>& rootGradientValues,
|
||||
std::unordered_map<Variable, ValuePtr>& backPropagatedGradientValuesForInputs) override;
|
||||
|
||||
virtual Dictionary Serialize() const override;
|
||||
|
||||
virtual size_t CurrentVersion() const override { return s_serializationVersion; }
|
||||
|
||||
static FunctionPtr Deserialize(const Dictionary& dictionary, const CNTK::DeviceDescriptor& device);
|
||||
|
||||
virtual const std::wstring& OpName() override
|
||||
{
|
||||
return CompositeFunctionOpName;
|
||||
}
|
||||
|
||||
template <typename FunctionType>
|
||||
static void Traverse(const FunctionPtr& rootFunction, const FunctionType& functor)
|
||||
{
|
||||
std::unordered_set<FunctionPtr> visitedFunctions;
|
||||
Traverse(rootFunction, visitedFunctions, functor);
|
||||
}
|
||||
|
||||
// Recursively traverses the Function graph underlying the 'rootFunction' invoking the provided functor for all visited nodes in the graph.
|
||||
template <typename FunctionType>
|
||||
static void Traverse(const FunctionPtr& rootFunction, std::unordered_set<FunctionPtr>& visitedFunctions, const FunctionType& functor)
|
||||
{
|
||||
visitedFunctions.insert(rootFunction);
|
||||
functor(rootFunction);
|
||||
|
||||
std::vector<Variable> rootFunctionInputs = rootFunction->Inputs();
|
||||
for (const auto& rootInput : rootFunctionInputs)
|
||||
{
|
||||
if (rootInput.IsOutput() && visitedFunctions.find(rootInput.Owner()) == visitedFunctions.end())
|
||||
{
|
||||
const auto& function = rootInput.Owner();
|
||||
Traverse(function, visitedFunctions, functor);
|
||||
}
|
||||
}
|
||||
}

    private:
        virtual void ReplacePlaceholdersInPlace(const std::unordered_map<Variable, Variable>& placeholderReplacements,
                                                std::unordered_set<const Function*>& visitedFunctions,
                                                std::unordered_set<Variable>& replacedPlaceholders) override;

        CompositeFunction(const FunctionPtr& rootFunction, std::unordered_set<FunctionPtr>&& allPrimitiveFunctions, const std::wstring& name, const std::wstring& uid = Internal::GenerateUid(L"CompositeFunction"))
            : Function({}, rootFunction->Outputs(), Dictionary(), rootFunction, name, uid),
              m_allPrimitiveFunctions(std::move(allPrimitiveFunctions)), m_networkMatricesAllocated(false)
        {}

        std::vector<Variable> DetermineInputs() const
        {
            const auto& root = RootFunction();
            std::unordered_set<FunctionPtr> visitedFunctions;
            return DetermineInputs(root, visitedFunctions);
        }

        // Recursively traverses the Function graph and populates the provided set of functions.
        static void Collect(const FunctionPtr& rootFunction, std::unordered_set<FunctionPtr>& functions)
        {
            // Call Traverse to get the set of all functions in the graph
            Traverse(rootFunction, functions, [](const FunctionPtr& f){});
        }

        // Recursively traverses the Function graph underlying the 'rootFunction' to determine all the leaves (aka inputs) of the graph
        static std::vector<Variable> DetermineInputs(const FunctionPtr& rootFunction, std::unordered_set<FunctionPtr>& visitedFunctions)
        {
            std::vector<Variable> inputs;
            std::unordered_set<Variable> uniqueInputs;
            Traverse(rootFunction, visitedFunctions, [&inputs, &uniqueInputs](const FunctionPtr& f){
                std::vector<Variable> functionInputs = f->Inputs();
                for (auto input : functionInputs)
                {
                    if (!input.IsOutput() && uniqueInputs.find(input) == uniqueInputs.end())
                    {
                        inputs.push_back(input);
                        uniqueInputs.insert(input);
                    }
                }
            });

            return inputs;
        }
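
        // For illustration only (assumed graph, not from the original source): for a root built as
        // Plus(Times(W, x), b), DetermineInputs would return the leaf Variables { W, x, b }, each
        // reported once in the order the traversal first encounters them, since only non-Output
        // inputs are collected.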

        template <typename ElementType>
        Microsoft::MSR::CNTK::ComputationNetworkPtr GetComputationNetwork(const DeviceDescriptor& device, const std::unordered_set<Variable>& backpropRoots, bool allocateNetworkMatrices);

        template <typename ElementType>
        static Microsoft::MSR::CNTK::ComputationNodeBasePtr CreateComputationNode(const Variable& variable,
                                                                                   PrimitiveFunction* primitiveFunction,
                                                                                   const std::vector<std::shared_ptr<Microsoft::MSR::CNTK::ComputationNode<ElementType>>>& inputNodes,
                                                                                   Microsoft::MSR::CNTK::ComputationNetworkPtr& network,
                                                                                   std::unordered_map<Variable, Microsoft::MSR::CNTK::ComputationNodeBasePtr>& variableToNodeMap);

        template <typename ElementType>
        static Microsoft::MSR::CNTK::ComputationNodeBasePtr GetOutputVariableNode(const Variable& variable,
                                                                                   Microsoft::MSR::CNTK::ComputationNetworkPtr& network,
                                                                                   Microsoft::MSR::CNTK::ComputationNetworkBuilder<ElementType>& builder,
                                                                                   std::unordered_map<Variable, Microsoft::MSR::CNTK::ComputationNodeBasePtr>& variableToNodeMap,
                                                                                   std::unordered_map<Variable, bool>& isVariableRootMap);

        template <typename ElementType>
        static Microsoft::MSR::CNTK::ComputationNodeBasePtr GetNode(const Variable& variable, Microsoft::MSR::CNTK::ComputationNetworkPtr& network,
                                                                     Microsoft::MSR::CNTK::ComputationNetworkBuilder<ElementType>& builder,
                                                                     std::unordered_map<Variable, Microsoft::MSR::CNTK::ComputationNodeBasePtr>& variableToNodeMap,
                                                                     std::unordered_map<Variable, bool>& isVariableRootMap);

        template <typename ElementType>
        static void PopulateComputationNodeValue(const std::pair<Variable, ValuePtr>& variableValue, Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode);
        void PopulateNetworkInputs(const std::unordered_map<Variable, ValuePtr>& arguments);

        template <typename ElementType>
        static void PopulateComputationNodeGradient(const std::pair<Variable, ValuePtr>& variableGradient, Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode);
        void PopulateNetworkGradients(const std::unordered_map<Variable, ValuePtr>& gradients);

        static void GetNodeOutputOrGradient(Variable var, ValuePtr& varValue, Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode, bool getGradient);
        void GetNetworkOutputs(std::unordered_map<Variable, ValuePtr>& outputs);
        void GetNetworkGradients(std::unordered_map<Variable, ValuePtr>& gradients);

        template <typename ElementType>
        static std::pair<std::shared_ptr<const Microsoft::MSR::CNTK::Matrix<ElementType>>, Microsoft::MSR::CNTK::MBLayoutPtr> GetCNTKImplMatrixAndMBLayoutFromValueObject(Variable var, const ValuePtr& value);

        template <typename ElementType>
        static ValuePtr GetValueObjectFromCNTKImplMatrixAndMBLayout(const NDShape& sampleShape, const Microsoft::MSR::CNTK::Matrix<ElementType>& matrix, const Microsoft::MSR::CNTK::MBLayoutPtr& layout, bool readOnly = true);
        template <typename ElementType>
        static ValuePtr GetValueObjectFromCNTKImplMatrixAndMBLayout(Variable var, const Microsoft::MSR::CNTK::Matrix<ElementType>& matrix, const Microsoft::MSR::CNTK::MBLayoutPtr& layout, bool readOnly = true);

        const std::vector<Variable>& GetArgumentDependencies(const Variable& output);

    private:

        // Set of all primitive functions in the graph underlying 'this' Function. Also keeps the primitive Function objects alive
        // by holding strong references to them
        std::unordered_set<FunctionPtr> m_allPrimitiveFunctions;

        // A map from Variable objects to ComputationNode objects in the ComputationNetwork instance that implements 'this' Composite Function
        std::unordered_map<Variable, Microsoft::MSR::CNTK::ComputationNodeBasePtr> m_variableToNodeMap;

        // A map that tells whether a Variable in the graph underlying 'this' Function is a root of the graph
        std::unordered_map<Variable, bool> m_isVariableRootMap;

        Microsoft::MSR::CNTK::ComputationNetworkPtr m_computationNetwork;

        // The backprop roots specified in the most recent 'Forward' call on 'this' Function.
        // This indicates for which of its roots 'this' Function has retained the required intermediate
        // state from the previous 'Forward' call, so that gradients can be backpropagated from them in
        // the next 'Backward' call.
        std::unordered_set<Variable> m_currentBackpropRoots;

        std::unordered_map<Variable, std::vector<Variable>> m_perOutputVarArgumentDependencies;

        bool m_networkMatricesAllocated;

        std::unordered_map<Parameter, size_t> m_lastRecordedParameterValueTimeStamps;

        static const size_t s_serializationVersion = 1;
    };

    inline std::vector<CNTK::Axis> DynamicAxesFromInternalDynamicAxisName(const std::wstring& internalDynamicAxisName)
    {
        std::vector<CNTK::Axis> inputVarDynamicAxes;
        if (internalDynamicAxisName.substr(0, CNTK::CompositeFunction::InternalDefaultDynamicAxisName.length()) == CNTK::CompositeFunction::InternalDefaultDynamicAxisName)
            inputVarDynamicAxes = { CNTK::Axis::DefaultDynamicAxis(), CNTK::Axis::DefaultBatchAxis() };
        else if (internalDynamicAxisName.substr(0, CNTK::CompositeFunction::InternalNoSequenceAxisName.length()) == CNTK::CompositeFunction::InternalNoSequenceAxisName)
            inputVarDynamicAxes = { CNTK::Axis::DefaultBatchAxis() };
        else
            inputVarDynamicAxes = { CNTK::Axis(internalDynamicAxisName), CNTK::Axis::DefaultBatchAxis() };

        return inputVarDynamicAxes;
    }

    // Construct the dynamic axis name to be used internally for the CNTK InputNodes
    inline std::wstring InternalDynamicAxisNameFromDynamicAxes(const std::vector<Axis>& dynamicAxes)
    {
        if (dynamicAxes.empty())
            LogicError("Empty dynamic axes set");

        if (dynamicAxes == std::vector<Axis>({ Axis::DefaultBatchAxis() }))
            return CompositeFunction::InternalNoSequenceAxisName;
        else if (dynamicAxes == std::vector<Axis>({ Axis::DefaultDynamicAxis(), Axis::DefaultBatchAxis() }))
            return CompositeFunction::InternalDefaultDynamicAxisName;
        else
            return dynamicAxes[0].Name();
    }
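
    // Illustrative note (derived from the two helpers above, not additional source): for the common
    // cases these functions round-trip, e.g. (hypothetical usage)
    //
    //     auto name = InternalDynamicAxisNameFromDynamicAxes({ Axis::DefaultDynamicAxis(), Axis::DefaultBatchAxis() });
    //     // name == CompositeFunction::InternalDefaultDynamicAxisName
    //     auto axes = DynamicAxesFromInternalDynamicAxisName(name);
    //     // axes == { Axis::DefaultDynamicAxis(), Axis::DefaultBatchAxis() }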

}
@ -6,7 +6,6 @@
#include "stdafx.h"
#include "CNTKLibrary.h"
#include "Utils.h"
#include "Function.h"
#include "Serialization.h"

namespace
@ -13,8 +13,8 @@
#include "CNTKLibrary.h"
#include "Utils.h"
#include "Serialization.h"
#include "Function.h"
#include <fcntl.h>
#include "PrimitiveFunction.h"

using namespace std;
@ -10,9 +10,9 @@
#endif

#include "CNTKLibrary.h"
#include "CompositeFunction.h"
#include "Utils.h"
#include "Value.h"
#include "Function.h"

namespace CNTK
{
@ -8,6 +8,7 @@
#include "stdafx.h"
#include "CNTKLibrary.h"
#include "Sequences.h"
#include "TensorView.h"
#include "Utils.h"

namespace CNTK
@ -5,8 +5,8 @@

#include "stdafx.h"
#include "CNTKLibrary.h"
#include "CompositeFunction.h"
#include "Serialization.h"
#include "Function.h"
#include "InputAndParamNodes.h"

namespace CNTK