CNTK v2 library: Added support for loading legacy v1 format models and saving models in v1 format

Amit Agarwal 2016-07-17 16:44:44 -07:00
Parent a67273394e
Commit 9ba93ab84e
16 changed files with 407 additions and 65 deletions

View file

@ -368,6 +368,7 @@ SEQUENCE_TRAINING_LIB_SRC +=\
endif
CNTKLIBRARY_SRC =\
$(SOURCEDIR)/CNTKv2LibraryDll/BackCompat.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/Common.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/Function.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/NDArrayView.cpp \

View file

@ -758,6 +758,13 @@ namespace CNTK
friend struct std::hash<Variable>;
public:
///
/// Create an 'Input' Variable.
///
Variable(const NDShape& shape, CNTK::DataType dataType, const wchar_t* name = L"")
: Variable(shape, dataType, std::wstring(name))
{}
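/// For instance (an illustrative sketch; the shape and name are arbitrary), this overload allows writing:
///     Variable features({ 937 }, DataType::Float, L"features");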
///
/// Create an 'Input' Variable.
///
@ -920,6 +927,11 @@ namespace CNTK
return first.m_dataFields == second.m_dataFields;
}
inline bool operator!=(const Variable& first, const Variable& second)
{
return !(first == second);
}
///
/// Denotes Parameter inputs of a Function.
///
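/// For example (a sketch with an illustrative shape; 'device' is assumed to be an existing DeviceDescriptor):
///     Parameter timesParam(MakeSharedObject<NDArrayView>(0.5f, NDShape({ 2, 4 }), device), L"timesParameters");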
@ -1396,8 +1408,19 @@ namespace CNTK
/// E.g., when creating a classification model, the CrossEntropy loss Function and the ClassificationError Function typically form the two roots
/// of the computation graph, which can be "Combine"d to create a single Function with 2 outputs; viz. the CrossEntropy loss and the ClassificationError output.
///
CNTK_API FunctionPtr Combine(const std::initializer_list<FunctionPtr>& operands, const std::wstring& name = L"");
CNTK_API FunctionPtr Combine(const std::vector<FunctionPtr>& operands, const std::wstring& name = L"");
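/// For example (a hedged sketch, assuming 'trainingLoss' and 'prediction' are existing FunctionPtr roots of the same graph):
///     auto classifierModel = Combine({ trainingLoss, prediction }, L"ClassifierModel");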
///
/// Load a legacy CNTK v1 format model
///
template <typename ElementType>
CNTK_API FunctionPtr LoadLegacyModel(const std::wstring& modelFile, const DeviceDescriptor& computeDevice = DeviceDescriptor::DefaultDevice());
///
/// Save a Composite Function instance to a file in CNTK legacy model format
///
template <typename ElementType>
CNTK_API void SaveAsLegacyModel(const FunctionPtr& rootFunction, const std::wstring& modelFile);
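/// A minimal round-trip sketch using the two functions above (the file name is illustrative):
///     SaveAsLegacyModel<float>(modelFunc, L"classifier.v1.dnn");
///     FunctionPtr reloaded = LoadLegacyModel<float>(L"classifier.v1.dnn", DeviceDescriptor::CPUDevice());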
///
/// A serializable value represents one of:
/// a) Boolean

View file

@ -0,0 +1,199 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#include "stdafx.h"
#include "CNTKLibrary.h"
#include "Function.h"
#include "ComputationNetworkBuilder.h"
#include "Utils.h"
#include "ComputationNode.h"
#include "InputAndParamNodes.h"
#include "NonlinearityNodes.h"
#include "LinearAlgebraNodes.h"
#include "RecurrentNodes.h"
#include "EvaluationNodes.h"
#include "TrainingNodes.h"
using namespace Microsoft::MSR::CNTK;
namespace CNTK
{
template <typename ElementType>
Variable GetVariable(const ComputationNodeBasePtr& node,
std::unordered_map<ComputationNodeBasePtr, Variable>& nodeToVariableMap,
std::unordered_map<Placeholder, Variable>& placeholderReplacements,
std::unordered_set<FunctionPtr>& allPrimitiveFunctions)
{
auto iter = nodeToVariableMap.find(node);
if (iter != nodeToVariableMap.end())
return iter->second;
Variable var;
NDShape varShape = AsNDShape(node->GetSampleLayout());
// The CNTK sample layouts may have trailing axes with dimension size of 1 which are automatically
// added when converting from NDShape to CNTK internal TensorShapes and are not present in the original
// shapes specified by the user. These should be truncated.
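// (Illustrative example: a user-specified 1D shape such as { 5 } may come back from the internal TensorShape as [5 x 1]; the trailing 1 is dropped below.)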
if (varShape.NumAxes() <= 2)
{
size_t numTrailingDimsToRemove = 0;
for (int i = varShape.NumAxes() - 1; i >= 0; --i)
{
if (varShape[i] == 1)
numTrailingDimsToRemove++;
else
break;
}
varShape = varShape.SubShape(0, varShape.NumAxes() - numTrailingDimsToRemove);
}
if (node->IsLeaf())
{
if (node->Is<InputValueBase<ElementType>>())
{
auto inputNode = node->As<InputValueBase<ElementType>>();
bool isSparse = node->Is<SparseInputValue<ElementType>>();
if (node->HasMBLayout())
{
// TODO: Currently only default dynamic axis is supported
const std::wstring defaultCNTKDynamicAxisName = L"";
if (inputNode->GetRequestedDynamicAxis() != defaultCNTKDynamicAxisName)
LogicError("Found dynamic axis named '%S' while currently only default dynamic axis named '%S' is supported!", node->GetMBLayout()->GetAxisName(), defaultCNTKDynamicAxisName);
var = Variable(varShape, isSparse, AsDataType<ElementType>(), node->GetLearningRateMultiplier() != 0, node->GetName());
}
else
{
// TODO: Allow creating inputs without a dynamic axis
LogicError("Found InputNode with no dynamic axis which is currently unsupported");
}
}
else if (node->Is<LearnableParameter<ElementType>>())
{
auto& matrix = node->As<ComputationNode<ElementType>>()->Value();
auto tensorView = new TensorView<ElementType>(std::make_shared<Matrix<ElementType>>(matrix.AsReference()), node->GetSampleLayout());
NDArrayViewPtr parameterValue = MakeSharedObject<NDArrayView>(AsDataType<ElementType>(), AsDeviceDescriptor(matrix.GetDeviceId()), AsStorageFormat(matrix.GetFormat()), varShape, false, tensorView);
var = Parameter(parameterValue, node->GetName());
}
else
LogicError("CNTK::LoadLegacyModel: Unsupported legacy CNTK node named '%S'", node->NodeName().c_str());
}
else
{
// This is a non-leaf node and maps to a primitive Function
auto placeholderVar = Placeholder(varShape);
nodeToVariableMap[node] = placeholderVar;
std::vector<Variable> inputVars(node->GetNumInputs());
for (size_t i = 0; i < inputVars.size(); ++i)
{
inputVars[i] = GetVariable<ElementType>(node->Input(i), nodeToVariableMap, placeholderReplacements, allPrimitiveFunctions);
if (inputVars[i].IsPlaceholder())
placeholderReplacements[Placeholder(inputVars[i])] = Variable();
}
PrimitiveOpType opType;
Dictionary primitiveFunctionConfigParameters;
if (node->OperationName() == OperationNameOf(TanhNode))
opType = PrimitiveOpType::Tanh;
else if (node->OperationName() == OperationNameOf(SigmoidNode))
opType = PrimitiveOpType::Sigmoid;
else if (node->OperationName() == OperationNameOf(ExpNode))
opType = PrimitiveOpType::Exp;
else if (node->OperationName() == OperationNameOf(TimesNode))
opType = PrimitiveOpType::Times;
else if (node->OperationName() == OperationNameOf(PlusNode))
opType = PrimitiveOpType::Plus;
else if (node->OperationName() == OperationNameOf(PastValueNode))
{
if (inputVars.size() == 1)
{
auto initialStateVar = Constant({}, node->As<PastValueNode<ElementType>>()->InitialActivationValue(), AsDeviceDescriptor(node->GetDeviceId()));
inputVars.insert(inputVars.begin(), initialStateVar);
}
primitiveFunctionConfigParameters[L"stepSize"] = DictionaryValue((size_t)node->As<PastValueNode<ElementType>>()->TimeStep());
opType = PrimitiveOpType::PastValue;
}
else if (node->OperationName() == OperationNameOf(FutureValueNode))
{
if (inputVars.size() == 1)
{
auto initialStateVar = Constant({}, node->As<FutureValueNode<ElementType>>()->InitialActivationValue(), AsDeviceDescriptor(node->GetDeviceId()));
inputVars.insert(inputVars.begin(), initialStateVar);
}
primitiveFunctionConfigParameters[L"stepSize"] = DictionaryValue((size_t)node->As<FutureValueNode<ElementType>>()->TimeStep());
opType = PrimitiveOpType::FutureValue;
}
else if (node->OperationName() == OperationNameOf(CrossEntropyWithSoftmaxNode))
{
std::swap(inputVars[0], inputVars[1]);
opType = PrimitiveOpType::CrossEntropyWithSoftmax;
}
else if (node->OperationName() == OperationNameOf(ErrorPredictionNode))
{
std::swap(inputVars[0], inputVars[1]);
opType = PrimitiveOpType::ClassificationError;
}
else if (node->OperationName() == OperationNameOf(ElementTimesNode))
opType = PrimitiveOpType::ElementTimes;
else if (node->OperationName() == OperationNameOf(SumElementsNode))
opType = PrimitiveOpType::ReduceSum;
else
LogicError("Unsupported ComputationNode with OperationName='%S' found when loading legacy CNTK model", node->OperationName().c_str());
FunctionPtr primitiveFunction = MakeSharedObject<PrimitiveFunction>(opType, inputVars, std::move(primitiveFunctionConfigParameters), node->GetName());
allPrimitiveFunctions.insert(primitiveFunction);
var = primitiveFunction->Output();
if (placeholderReplacements.find(placeholderVar) != placeholderReplacements.end())
placeholderReplacements[placeholderVar] = var;
}
nodeToVariableMap[node] = var;
return var;
}
template <typename ElementType>
FunctionPtr LoadLegacyModel(const std::wstring& modelFile, const DeviceDescriptor& computeDevice /*= DeviceDescriptor::DefaultDevice()*/)
{
ComputationNetworkPtr net = make_shared<ComputationNetwork>(AsCNTKImplDeviceId(computeDevice));
net->Load<ElementType>(modelFile);
// Now traverse the model and construct the Function graph
std::unordered_map<ComputationNodeBasePtr, Variable> nodeToVariableMap;
std::unordered_map<Placeholder, Variable> placeholderReplacements;
std::unordered_set<FunctionPtr> allPrimitiveFunctions;
std::vector<FunctionPtr> rootFunctions;
auto& networkRoots = net->RootNodes();
for (auto& rootNode : networkRoots)
{
if (rootNode->IsLeaf())
continue;
rootFunctions.push_back(GetVariable<ElementType>(rootNode, nodeToVariableMap, placeholderReplacements, allPrimitiveFunctions).Owner());
}
auto rootComposite = Combine(rootFunctions);
rootComposite->ReplacePlaceholders(placeholderReplacements);
return rootComposite;
}
template <typename ElementType>
void SaveAsLegacyModel(const FunctionPtr& rootFunction, const std::wstring& modelFile)
{
CompositeFunction* compositeFunction = dynamic_cast<CompositeFunction*>(rootFunction.get());
if (compositeFunction == nullptr)
InvalidArgument("Primitive (aka non-composite) Function instances cannot be saved");
auto computationNetwork = compositeFunction->GetComputationNetwork<ElementType>(DeviceDescriptor::CPUDevice(), {});
computationNetwork->Save(modelFile);
}
// Template instantiations
template CNTK_API FunctionPtr LoadLegacyModel<float>(const std::wstring& modelFile, const DeviceDescriptor& computeDevice /*= DeviceDescriptor::DefaultDevice()*/);
template CNTK_API FunctionPtr LoadLegacyModel<double>(const std::wstring& modelFile, const DeviceDescriptor& computeDevice /*= DeviceDescriptor::DefaultDevice()*/);
template CNTK_API void SaveAsLegacyModel<float>(const FunctionPtr& rootFunction, const std::wstring& modelFile);
template CNTK_API void SaveAsLegacyModel<double>(const FunctionPtr& rootFunction, const std::wstring& modelFile);
}

View file

@ -134,6 +134,7 @@
<ClInclude Include="targetver.h" />
</ItemGroup>
<ItemGroup>
<ClCompile Include="BackCompat.cpp" />
<ClCompile Include="Common.cpp" />
<ClCompile Include="dllmain.cpp">
<CompileAsManaged>false</CompileAsManaged>

View file

@ -11,6 +11,7 @@
<ClCompile Include="Utils.cpp" />
<ClCompile Include="NDMask.cpp" />
<ClCompile Include="Learner.cpp" />
<ClCompile Include="BackCompat.cpp" />
</ItemGroup>
<ItemGroup>
<ClInclude Include="stdafx.h" />

View file

@ -126,7 +126,14 @@ namespace CNTK
}
else if (variable.IsInput())
{
// TODO: Specify dynamic axis
// TODO: Support inputs with > 1 dynamic axes
if (variable.DynamicAxes().size() != 1)
LogicError("Currently only Input variables with one dynamic axis are supported");
auto dynamicAxis = variable.DynamicAxes()[0];
if (dynamicAxis != Axis::DefaultDynamicAxis())
LogicError("Currently only Input variables with DefaultDynamicAxis are supported");
if (IsSparseInput(variable))
computationNodePtr = builder.CreateSparseInputNode(variable.Name(), AsTensorShape(variable.Shape()));
else
@ -872,7 +879,7 @@ namespace CNTK
return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::Tanh, std::vector<Variable>({ operand }), Dictionary(), name), name);
}
FunctionPtr Combine(const std::initializer_list<FunctionPtr>& operands, const std::wstring& name/* = L""*/)
FunctionPtr Combine(const std::vector<FunctionPtr>& operands, const std::wstring& name/* = L""*/)
{
std::unordered_set<FunctionPtr> uniqueOperands;
std::vector<Variable> inputs;

View file

@ -220,7 +220,7 @@ namespace CNTK
{
assert(inputs.size() == 2);
if (inputs[0].Shape().NumAxes() > 1)
if ((inputs[0].Shape().NumAxes() > 2) || ((inputs[0].Shape().NumAxes() > 1) && (inputs[0].Shape()[1] != 1)))
InvalidArgument("The shape of input operands for the %s operation should have at most one axis", PrimitiveOpTypeName(op));
auto predictionShape = inputs[0].Shape();
@ -292,6 +292,9 @@ namespace CNTK
template <typename T, typename ...CtorArgTypes>
friend inline std::shared_ptr<T> MakeSharedObject(CtorArgTypes&& ...ctorArgs);
template <typename ElementType>
friend void SaveAsLegacyModel(const FunctionPtr& rootFunction, const std::wstring& modelFile);
public:
static CompositeFunctionPtr Create(const FunctionPtr& rootFunction, const std::wstring& name = L"")
{

View file

@ -78,6 +78,18 @@ namespace CNTK
LogicError("Unknown DataType");
}
inline NDShape AsNDShape(const Microsoft::MSR::CNTK::TensorShape& tensorShape)
{
// The TensorShape should be flattenable to 1D
for (size_t i = 1; i < tensorShape.GetRank(); ++i)
{
if (!tensorShape.CanFlatten(i))
InvalidArgument("AsNDShape() can only be called for TensorShapes that can be flattened to 1D");
}
return std::vector<size_t>(tensorShape.GetDims().begin(), tensorShape.GetDims().end());
}
inline Microsoft::MSR::CNTK::TensorShape AsTensorShape(const NDShape& viewShape)
{
const size_t maxNumAxesSupportedByTensorView = 12;

View file

@ -522,6 +522,8 @@ public:
}
const std::vector<ComputationNodeBasePtr>& RootNodes() const { return m_allRoots; }
// these are specified as such by the user
const std::vector<ComputationNodeBasePtr>& FeatureNodes() const { return m_featureNodes ; }
const std::vector<ComputationNodeBasePtr>& LabelNodes() const { return m_labelNodes ; }

View file

@ -38,7 +38,8 @@
#define CNTK_MODEL_VERSION_7 7 // ElemType tag in model file
#define CNTK_MODEL_VERSION_8 8 // DynamicAxis for inputs
#define CNTK_MODEL_VERSION_9 9 // Transpose flag in ConvolutionNode to support deconvolution.
#define CURRENT_CNTK_MODEL_VERSION CNTK_MODEL_VERSION_9
#define CNTK_MODEL_VERSION_10 10 // Learning rate multiplier for input nodes.
#define CURRENT_CNTK_MODEL_VERSION CNTK_MODEL_VERSION_10
extern bool g_shareNodeValueMatrices;

View file

@ -162,7 +162,7 @@ class InputValueBase : public ComputationNode<ElemType>, public NumInputs<0>, pu
typedef ComputationNode<ElemType> Base;
UsingComputationNodeMembers;
void Init(const TensorShape& sampleLayout, bool isSparse, const std::wstring axisName)
void Init(const TensorShape& sampleLayout, bool isSparse, const std::wstring axisName, float learningRateMultiplier = 0)
{
m_isSparse = isSparse;
MarkValueNonSharable();
@ -171,7 +171,7 @@ class InputValueBase : public ComputationNode<ElemType>, public NumInputs<0>, pu
SetDims(sampleLayout, HasMBLayout()); // also called when reloading a file. Then we have an MBLayout, otherwise not yet
UpdateFunctionValuesSize(); // we must allocate the matrix so that the readers get objects with valid row dimensions (some readers expect that)
SetLearningRateMultiplier(0);
SetLearningRateMultiplier(learningRateMultiplier);
m_dynamicAxisNodeName = axisName;
}
@ -225,9 +225,9 @@ protected:
Init(ImageDimensions::AsTensorShape(configp->Get(L"imageWidth"), configp->Get(L"imageHeight"), configp->Get(L"imageChannels"), ImageLayoutKindFrom(configp->Get(L"imageLayout"))), isSparse, axisName);
}
public:
virtual const std::wstring GetRequestedDynamicAxis() const { return m_dynamicAxisNodeName; }
public:
virtual void Save(File& fstream) const override
{
Base::Save(fstream);
@ -239,6 +239,8 @@ public:
unsigned int nrAxes = 1;
fstream << nrAxes;
fstream << m_dynamicAxisNodeName;
fstream << m_learningRateMultiplier;
}
virtual void Load(File& fstream, size_t modelVersion) override
@ -268,7 +270,12 @@ public:
}
else
m_dynamicAxisNodeName = L""; // Use default
Init(sampleLayout, m_isSparse, m_dynamicAxisNodeName);
float learningRateMultiplier = 0;
if (modelVersion >= CNTK_MODEL_VERSION_10)
fstream >> learningRateMultiplier;
Init(sampleLayout, m_isSparse, m_dynamicAxisNodeName, learningRateMultiplier);
}
// InputValue must not resize its inputs because that might destroy it. It should already have the correct size.

View file

@ -464,6 +464,9 @@ public:
LogicError("Unrecognized direction in DelayedValueNodeBase");
}
int TimeStep() const { return m_timeStep; }
ElemType InitialActivationValue() const { return m_initialActivationValue; }
protected:
ElemType m_initialActivationValue; // starting value for hidden activation vector at boundary
Matrix<ElemType> m_delayedValue; // saves the activation of the previous step that this node points to

View file

@ -2,6 +2,7 @@
#include <exception>
#include <algorithm>
#include "CNTKLibrary.h"
static const double relativeTolerance = 0.001f;
static const double absoluteTolerance = 0.000001f;
@ -18,3 +19,53 @@ inline void FloatingPointVectorCompare(const std::vector<ElementType>& first, co
throw std::runtime_error(message);
}
}
#pragma warning(push)
#pragma warning(disable: 4996)
template <typename ElementType>
inline void SaveAndReloadModel(CNTK::FunctionPtr& functionPtr, const std::vector<CNTK::Variable*>& variables, const CNTK::DeviceDescriptor& device)
{
static std::wstring s_tempModelPath = L"feedForward.net";
if ((_wunlink(s_tempModelPath.c_str()) != 0) && (errno != ENOENT))
RuntimeError("Error deleting file '%ls': %s", s_tempModelPath.c_str(), strerror(errno));
std::unordered_map<std::wstring, Variable*> inputVarNames;
std::unordered_map<std::wstring, Variable*> outputVarNames;
for (auto varPtr : variables)
{
auto retVal = varPtr->IsOutput() ? outputVarNames.insert({ varPtr->Owner()->Name(), varPtr }) : inputVarNames.insert({ varPtr->Name(), varPtr });
if (!retVal.second)
RuntimeError("SaveAndReloadModel: Multiple variables having same name cannot be restored after save and reload");
}
SaveAsLegacyModel<ElementType>(functionPtr, s_tempModelPath);
functionPtr = LoadLegacyModel<ElementType>(s_tempModelPath, device);
if (_wunlink(s_tempModelPath.c_str()) != 0)
RuntimeError("Error deleting file '%ls': %s", s_tempModelPath.c_str(), strerror(errno));
auto inputs = functionPtr->Inputs();
for (auto inputVarInfo : inputVarNames)
{
auto newInputVar = *(std::find_if(inputs.begin(), inputs.end(), [inputVarInfo](const Variable& var) {
return (var.Name() == inputVarInfo.first);
}));
*(inputVarInfo.second) = newInputVar;
}
auto outputs = functionPtr->Outputs();
for (auto outputVarInfo : outputVarNames)
{
auto newOutputVar = *(std::find_if(outputs.begin(), outputs.end(), [outputVarInfo](const Variable& var) {
return (var.Owner()->Name() == outputVarInfo.first);
}));
*(outputVarInfo.second) = newOutputVar;
}
}
#pragma warning(pop)

View file

@ -18,7 +18,13 @@ FunctionPtr FullyConnectedDNNLayer(Variable input, size_t outputDim, const Devic
return nonLinearity(plusFunction);
}
FunctionPtr FullyConnectedFeedForwardClassifierNet(Variable input, size_t numOutputClasses, size_t hiddenLayerDim, size_t numHiddenLayers, const DeviceDescriptor& device, const std::function<FunctionPtr(const FunctionPtr&)>& nonLinearity)
FunctionPtr FullyConnectedFeedForwardClassifierNet(Variable input,
size_t numOutputClasses,
size_t hiddenLayerDim,
size_t numHiddenLayers,
const DeviceDescriptor& device,
const std::function<FunctionPtr(const FunctionPtr&)>& nonLinearity,
const std::wstring& outputName)
{
assert(numHiddenLayers >= 1);
auto classifierRoot = FullyConnectedDNNLayer(input, hiddenLayerDim, device, nonLinearity);
@ -26,11 +32,12 @@ FunctionPtr FullyConnectedFeedForwardClassifierNet(Variable input, size_t numOut
classifierRoot = FullyConnectedDNNLayer(classifierRoot, hiddenLayerDim, device, nonLinearity);
auto outputTimesParam = Parameter(NDArrayView::RandomUniform<float>({ numOutputClasses, hiddenLayerDim }, -0.5, 0.5, 1, device));
classifierRoot = Times(outputTimesParam, classifierRoot);
return classifierRoot;
return Times(outputTimesParam, classifierRoot, outputName);
}
void TestFeedForwardNetworkCreation(const DeviceDescriptor& device)
std::wstring s_tempModelPath = L"feedForward.net";
void TestFeedForwardNetworkCreation(const DeviceDescriptor& device, bool testSaveAndReLoad)
{
using namespace std::placeholders;
@ -39,14 +46,17 @@ void TestFeedForwardNetworkCreation(const DeviceDescriptor& device)
const size_t numHiddenLayers = 6;
const size_t hiddenLayersDim = 2048;
Variable inputVar({ inputDim }, DataType::Float, L"Features");
auto classifierOutputFunction = FullyConnectedFeedForwardClassifierNet(inputVar, numOutputClasses, hiddenLayersDim, numHiddenLayers, device, std::bind(Sigmoid, _1, L""));
Variable inputVar({ inputDim }, DataType::Float, L"features");
auto classifierOutputFunction = FullyConnectedFeedForwardClassifierNet(inputVar, numOutputClasses, hiddenLayersDim, numHiddenLayers, device, std::bind(Sigmoid, _1, L""), L"classifierOutput");
Variable classifierOutput = classifierOutputFunction;
Variable labelsVar({ numOutputClasses }, DataType::Float, L"Labels");
auto trainingLossFunction = CNTK::CrossEntropyWithSoftmax(classifierOutputFunction, labelsVar, L"LossFunction");
auto predictionFunction = CNTK::ClassificationError(classifierOutputFunction, labelsVar, L"ClassificationError");
auto trainingLossFunction = CNTK::CrossEntropyWithSoftmax(classifierOutput, labelsVar, L"LossFunction");
Variable trainingLoss = trainingLossFunction;
auto predictionFunction = CNTK::ClassificationError(classifierOutput, labelsVar, L"ClassificationError");
Variable prediction = predictionFunction;
auto ffNet = CNTK::Combine({ trainingLossFunction, predictionFunction, classifierOutputFunction }, L"ClassifierModel");
auto ffNet = CNTK::Combine({ trainingLoss.Owner(), prediction.Owner(), classifierOutput.Owner() }, L"ClassifierModel");
// Now test the structure
if (ffNet->Parameters().size() != ((numHiddenLayers * 2) + 1))
@ -58,6 +68,9 @@ void TestFeedForwardNetworkCreation(const DeviceDescriptor& device)
if (ffNet->Outputs().size() != 3)
throw std::runtime_error("TestFeedForwardNetworkCreation: Function does not have expected Output count");
if (testSaveAndReLoad)
SaveAndReloadModel<float>(ffNet, { &inputVar, &labelsVar, &trainingLoss, &prediction, &classifierOutput }, device);
// Run Forward and backward a few times
size_t iterationCount = 4;
unsigned int randSeed = 2;
@ -69,22 +82,22 @@ void TestFeedForwardNetworkCreation(const DeviceDescriptor& device)
for (size_t i = 0; i < inputData.size(); ++i)
inputData[i] = ((float)rand()) / RAND_MAX;
NDShape inputShape = { inputDim, 1, numSamples };
NDShape inputShape = inputVar.Shape().AppendShape({ 1, numSamples });
ValuePtr inputValue = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(inputShape, inputData.data(), inputData.size(), DeviceDescriptor::CPUDevice(), true));
std::vector<float> labelData(numOutputClasses * numSamples, 0);
for (size_t i = 0; i < numSamples; ++i)
labelData[(i*numOutputClasses) + (rand() % numOutputClasses)] = 1;
NDShape labelShape = { numOutputClasses, 1, numSamples };
NDShape labelShape = labelsVar.Shape().AppendShape({ 1, numSamples });
ValuePtr labelValue = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(labelShape, labelData.data(), labelData.size(), DeviceDescriptor::CPUDevice(), true));
ValuePtr outputValue, predictionErrorValue;
std::unordered_map<Variable, ValuePtr> outputs = { { classifierOutputFunction->Output(), outputValue }, { predictionFunction->Output(), predictionErrorValue } };
auto backpropState = ffNet->Forward({ { inputVar, inputValue }, { labelsVar, labelValue } }, outputs, device, { trainingLossFunction->Output() });
std::unordered_map<Variable, ValuePtr> outputs = { { classifierOutput, outputValue }, { prediction, predictionErrorValue } };
auto backpropState = ffNet->Forward({ { inputVar, inputValue }, { labelsVar, labelValue } }, outputs, device, { trainingLoss });
// Perform backprop
NDShape outputShape = trainingLossFunction->Output().Shape();
NDShape outputShape = trainingLoss.Shape();
std::vector<float> rootGradientsData(outputShape.TotalSize(), 1);
ValuePtr rootGradientValue = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(outputShape, rootGradientsData.data(), rootGradientsData.size(), DeviceDescriptor::CPUDevice(), true));
std::unordered_map<Variable, ValuePtr> paramGradients;
@ -92,7 +105,7 @@ void TestFeedForwardNetworkCreation(const DeviceDescriptor& device)
for (auto iter = allParams.begin(); iter != allParams.end(); ++iter)
paramGradients[*iter] = nullptr;
ffNet->Backward(backpropState, { { trainingLossFunction->Output(), rootGradientValue } }, paramGradients);
ffNet->Backward(backpropState, { { trainingLoss, rootGradientValue } }, paramGradients);
}
}
@ -103,15 +116,19 @@ void TestTimesAndPlus(size_t inputDim,
const DeviceDescriptor& device,
size_t numIterations,
bool usePreAllocatedOutputs,
bool outputOnSpecifiedDevice = false,
bool outputOnSpecifiedDevice,
bool testSaveAndReLoad,
unsigned int seed = 1)
{
Parameter timesParam(MakeSharedObject<NDArrayView>((ElementType)0.5, NDShape({ outputDim, inputDim }), device));
Parameter plusParam(MakeSharedObject<NDArrayView>((ElementType)1.2, std::initializer_list<size_t>({ outputDim }), device));
Parameter timesParam(MakeSharedObject<NDArrayView>((ElementType)0.5, NDShape({ outputDim, inputDim }), device), L"timesParameters");
Parameter plusParam(MakeSharedObject<NDArrayView>((ElementType)1.2, std::initializer_list<size_t>({ outputDim }), device), L"plusParameters");
Variable inputVar({ inputDim }, AsDataType<ElementType>(), L"input");
auto timesAndPlusFunc = Plus(plusParam, Times(timesParam, inputVar));
if (testSaveAndReLoad)
SaveAndReloadModel<ElementType>(timesAndPlusFunc, { &inputVar, &timesParam, &plusParam }, device);
srand(seed);
for (size_t iterIdx = 0; iterIdx < numIterations; ++iterIdx)
{
@ -119,10 +136,10 @@ void TestTimesAndPlus(size_t inputDim,
for (size_t i = 0; i < inputData.size(); ++i)
inputData[i] = ((ElementType)rand()) / RAND_MAX;
NDShape inputShape = { inputDim, 1, numSamples };
NDShape inputShape = inputVar.Shape().AppendShape({ 1, numSamples });
ValuePtr inputValue = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(inputShape, inputData.data(), inputData.size(), DeviceDescriptor::CPUDevice(), true));
NDShape outputShape = { outputDim, 1, numSamples };
NDShape outputShape = timesAndPlusFunc->Output().Shape().AppendShape({ 1, numSamples });
std::vector<ElementType> outputData(outputShape.TotalSize());
ValuePtr outputValue;
if (usePreAllocatedOutputs)
@ -235,12 +252,14 @@ void TestTimesAndPlus(size_t inputDim,
void FeedForwardTests()
{
TestTimesAndPlus<double>(4, 2, 5, DeviceDescriptor::CPUDevice(), 3, true, true);
TestTimesAndPlus<double>(4, 2, 5, DeviceDescriptor::CPUDevice(), 3, true, true, true);
#ifndef CPUONLY
TestTimesAndPlus<float>(145, 32, 2, DeviceDescriptor::GPUDevice(0), 10, true, false);
TestTimesAndPlus<double>(145, 15, 200, DeviceDescriptor::GPUDevice(0), 21, false);
TestTimesAndPlus<float>(145, 32, 2, DeviceDescriptor::GPUDevice(0), 10, true, false, true);
TestTimesAndPlus<double>(145, 15, 200, DeviceDescriptor::GPUDevice(0), 21, false, false, false);
TestFeedForwardNetworkCreation(DeviceDescriptor::GPUDevice(0));
TestFeedForwardNetworkCreation(DeviceDescriptor::GPUDevice(0), true);
TestFeedForwardNetworkCreation(DeviceDescriptor::GPUDevice(0), false);
#endif
TestFeedForwardNetworkCreation(DeviceDescriptor::CPUDevice());
TestFeedForwardNetworkCreation(DeviceDescriptor::CPUDevice(), false);
TestFeedForwardNetworkCreation(DeviceDescriptor::CPUDevice(), true);
}

View file

@ -119,7 +119,7 @@ FunctionPtr LSTMPComponentWithSelfStabilization(Variable input, size_t outputDim
}
template <typename ElementType>
FunctionPtr LSTMNet(Variable features, size_t cellDim, size_t hiddenDim, size_t numOutputClasses, size_t numLSTMLayers, const DeviceDescriptor& device)
FunctionPtr LSTMNet(Variable features, size_t cellDim, size_t hiddenDim, size_t numOutputClasses, size_t numLSTMLayers, const DeviceDescriptor& device, const std::wstring& outputName)
{
assert(numLSTMLayers >= 1);
auto classifierRoot = LSTMPComponentWithSelfStabilization<ElementType>(features, hiddenDim, cellDim, device);
@ -133,11 +133,11 @@ FunctionPtr LSTMNet(Variable features, size_t cellDim, size_t hiddenDim, size_t
auto sW = Parameter({}, (ElementType)0.0, device);
auto expsW = Exp(sW);
return Plus(Times(W, ElementTimes(expsW, classifierRoot)), b);
return Plus(Times(W, ElementTimes(expsW, classifierRoot)), b, outputName);
}
template <typename ElementType>
void TestRecurrentNetworkCreation(const DeviceDescriptor& device)
void TestRecurrentNetworkCreation(const DeviceDescriptor& device, bool testSaveAndReLoad)
{
const size_t inputDim = 937;
const size_t numLSTMLayers = 3;
@ -146,11 +146,14 @@ void TestRecurrentNetworkCreation(const DeviceDescriptor& device)
const size_t numOutputClasses = 9304;
Variable features({ inputDim }, AsDataType<ElementType>(), L"features");
auto classifierOutputFunction = LSTMNet<ElementType>(features, cellDim, hiddenDim, numOutputClasses, numLSTMLayers, device);
auto classifierOutputFunction = LSTMNet<ElementType>(features, cellDim, hiddenDim, numOutputClasses, numLSTMLayers, device, L"classifierOutput");
Variable classifierOutput = classifierOutputFunction;
Variable labelsVar = Variable({ numOutputClasses }, AsDataType<ElementType>(), L"labels");
auto trainingLossFunction = CrossEntropyWithSoftmax(classifierOutputFunction, labelsVar, L"lossFunction");
Variable trainingLoss = trainingLossFunction;
auto predictionFunction = ClassificationError(classifierOutputFunction, labelsVar, L"classificationError");
Variable prediction = predictionFunction;
auto LSTMClassifier = Combine({ trainingLossFunction, predictionFunction, classifierOutputFunction }, L"LSTMClassifier");
@ -164,6 +167,9 @@ void TestRecurrentNetworkCreation(const DeviceDescriptor& device)
if (LSTMClassifier->Parameters().size() != ((numLSTMLayers * 28) + 3))
throw std::runtime_error("TestFeedForwardNetworkCreation: Function does not have expected Parameter count");
if (testSaveAndReLoad)
SaveAndReloadModel<ElementType>(LSTMClassifier, { &features, &labelsVar, &trainingLoss, &prediction, &classifierOutput }, device);
// Run Forward and backward a few times
size_t iterationCount = 3;
unsigned int randSeed = 2;
@ -206,11 +212,11 @@ void TestRecurrentNetworkCreation(const DeviceDescriptor& device)
ValuePtr labelValue = Value::Create({ numOutputClasses }, labelsData, device, true);
ValuePtr outputValue, predictionErrorValue;
std::unordered_map<Variable, ValuePtr> outputs = { { classifierOutputFunction->Output(), outputValue }, { predictionFunction->Output(), predictionErrorValue } };
auto backpropState = LSTMClassifier->Forward({ { features, inputValue }, { labelsVar, labelValue } }, outputs, device, { trainingLossFunction->Output() });
std::unordered_map<Variable, ValuePtr> outputs = { { classifierOutput, outputValue }, { prediction, predictionErrorValue } };
auto backpropState = LSTMClassifier->Forward({ { features, inputValue }, { labelsVar, labelValue } }, outputs, device, { trainingLoss });
// Perform backprop
NDShape outputShape = trainingLossFunction->Output().Shape();
NDShape outputShape = trainingLoss.Shape();
std::vector<ElementType> rootGradientsData(outputShape.TotalSize(), 1);
ValuePtr rootGradientValue = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(outputShape, rootGradientsData.data(), rootGradientsData.size(), DeviceDescriptor::CPUDevice(), true));
std::unordered_map<Variable, ValuePtr> paramGradients;
@ -218,7 +224,7 @@ void TestRecurrentNetworkCreation(const DeviceDescriptor& device)
for (auto iter = allParams.begin(); iter != allParams.end(); ++iter)
paramGradients[*iter] = nullptr;
LSTMClassifier->Backward(backpropState, { { trainingLossFunction->Output(), rootGradientValue } }, paramGradients);
LSTMClassifier->Backward(backpropState, { { trainingLoss, rootGradientValue } }, paramGradients);
}
}
@ -228,6 +234,7 @@ void TestSimpleRecurrence(size_t inputDim,
size_t maxAllowedSequenceLength,
size_t numSequences,
const DeviceDescriptor& device,
bool testSaveAndReLoad,
size_t numIterations,
bool useFutureValue,
bool useSparseInputs,
@ -237,24 +244,29 @@ void TestSimpleRecurrence(size_t inputDim,
if (useOneHotSparseInputs && !useSparseInputs)
throw std::runtime_error("useOneHotSparseInputs option can only be true when useSparseInputs is true");
Parameter timesParam(MakeSharedObject<NDArrayView>((ElementType)0.5, NDShape({ outputDim, inputDim }), device));
Parameter plusParam(MakeSharedObject<NDArrayView>((ElementType)0.1, std::initializer_list<size_t>({ outputDim }), device));
Parameter timesParam(MakeSharedObject<NDArrayView>((ElementType)0.5, NDShape({ outputDim, inputDim }), device), L"timesParameters");
Parameter plusParam(MakeSharedObject<NDArrayView>((ElementType)0.1, std::initializer_list<size_t>({ outputDim }), device), L"plusParameters");
Variable inputVar({ inputDim }, useSparseInputs, AsDataType<ElementType>(), true, L"input");
auto placeholder = Placeholder({ outputDim });
auto plusOutput = Plus(plusParam, Plus(placeholder, Times(timesParam, inputVar)));
auto plusOutputFunction = Plus(plusParam, Plus(placeholder, Times(timesParam, inputVar)), L"plusOutput");
FunctionPtr placeholderReplacement;
if (useFutureValue)
placeholderReplacement = FutureValue(Constant({}, (ElementType)0.0, device), plusOutput, 1);
placeholderReplacement = FutureValue(Constant({}, (ElementType)0.0, device), plusOutputFunction, 1);
else
placeholderReplacement = PastValue(Constant({}, (ElementType)0.0, device), plusOutput, 1);
placeholderReplacement = PastValue(Constant({}, (ElementType)0.0, device), plusOutputFunction, 1);
plusOutput = plusOutput->ReplacePlaceholders({ { placeholder, placeholderReplacement } });
plusOutputFunction = plusOutputFunction->ReplacePlaceholders({ { placeholder, placeholderReplacement } });
Variable plusOutput = plusOutputFunction;
auto reducedOutput = ReduceSum(plusOutput);
auto reducedOutputFunction = ReduceSum(plusOutput, L"sum");
Variable reducedOutput = reducedOutputFunction;
auto rootFunc = Combine({ reducedOutput, plusOutput });
auto rootFunc = Combine({ reducedOutputFunction, plusOutputFunction });
if (testSaveAndReLoad)
SaveAndReloadModel<ElementType>(rootFunc, { &inputVar, &timesParam, &plusParam, &plusOutput, &reducedOutput }, device);
srand(seed);
for (size_t iterIdx = 0; iterIdx < numIterations; ++iterIdx)
@ -268,7 +280,7 @@ void TestSimpleRecurrence(size_t inputDim,
maxActualSequenceLength = sequenceLengths[i];
}
NDShape inputShape = { inputDim, maxActualSequenceLength, numSequences };
NDShape inputShape = inputVar.Shape().AppendShape({ maxActualSequenceLength, numSequences });
ValuePtr inputValue;
size_t totalNumInputSamples = maxActualSequenceLength * numSequences;
std::vector<ElementType> inputData(inputDim * totalNumInputSamples, useSparseInputs ? 0 : std::numeric_limits<ElementType>::quiet_NaN());
@ -330,12 +342,12 @@ void TestSimpleRecurrence(size_t inputDim,
std::vector<ElementType> reducedOutputData(reducedOutputShape.TotalSize());
ValuePtr reducedOutputValue = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(reducedOutputShape, reducedOutputData.data(), reducedOutputData.size(), DeviceDescriptor::CPUDevice(), false));
NDShape plusOutputShape = plusOutput->Output().Shape().AppendShape({ maxActualSequenceLength, numSequences });
NDShape plusOutputShape = plusOutput.Shape().AppendShape({ maxActualSequenceLength, numSequences });
std::vector<ElementType> plusOutputData(plusOutputShape.TotalSize());
ValuePtr plusOutputValue = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(plusOutputShape, plusOutputData.data(), plusOutputData.size(), DeviceDescriptor::CPUDevice(), false), MakeSharedObject<NDMask>(inputValue->Mask()->Shape(), inputValue->Mask()->Device()));
std::unordered_map<Variable, ValuePtr> outputs = { { reducedOutput->Output(), reducedOutputValue }, { plusOutput->Output(), plusOutputValue } };
auto backpropState = rootFunc->Forward({ { inputVar, inputValue } }, outputs, device, { plusOutput->Output() });
std::unordered_map<Variable, ValuePtr> outputs = { { reducedOutput, reducedOutputValue }, { plusOutput, plusOutputValue } };
auto backpropState = rootFunc->Forward({ { inputVar, inputValue } }, outputs, device, { plusOutput });
// Perform backprop
std::vector<ElementType> rootGradientsData(plusOutputShape.TotalSize(), std::numeric_limits<ElementType>::quiet_NaN());
@ -362,7 +374,7 @@ void TestSimpleRecurrence(size_t inputDim,
ValuePtr inputGradientValue = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(inputShape, inputGradientData.data(), inputGradientData.size(), DeviceDescriptor::CPUDevice(), false), inputValue->Mask()->DeepClone());
std::unordered_map<Variable, ValuePtr> outGradients = { { inputVar, inputGradientValue }, { plusParam, plusParameterGradientValue }, { timesParam, timesParameterGradientValue } };
rootFunc->Backward(backpropState, { { plusOutput->Output(), rootGradientValue } }, outGradients);
rootFunc->Backward(backpropState, { { plusOutput, rootGradientValue } }, outGradients);
// Verify forward prop results
std::vector<ElementType> expectedPlusOutputData(plusOutputShape.TotalSize(), 0);
@ -473,19 +485,19 @@ void TestSimpleRecurrence(size_t inputDim,
void RecurrentFunctionTests()
{
TestSimpleRecurrence<float>(2, 1, 4, 1, DeviceDescriptor::CPUDevice(), 3, false, false);
TestSimpleRecurrence<float>(2, 1, 4, 1, DeviceDescriptor::CPUDevice(), true, 3, false, false);
#ifndef CPUONLY
TestSimpleRecurrence<double>(11, 9, 16, 7, DeviceDescriptor::GPUDevice(0), 5, true, false);
TestSimpleRecurrence<double>(11, 9, 16, 7, DeviceDescriptor::GPUDevice(0), true, 5, true, false);
#endif
TestSimpleRecurrence<double>(1000, 9, 16, 3, DeviceDescriptor::CPUDevice(), 2, true, true);
TestSimpleRecurrence<double>(1000, 9, 16, 3, DeviceDescriptor::CPUDevice(), false, 2, true, true);
#ifndef CPUONLY
TestSimpleRecurrence<float>(5000, 200, 19, 6, DeviceDescriptor::GPUDevice(0), 3, false, true);
TestSimpleRecurrence<double>(1000, 9, 16, 3, DeviceDescriptor::GPUDevice(0), 3, true, true, true);
TestSimpleRecurrence<float>(5000, 200, 19, 6, DeviceDescriptor::GPUDevice(0), false, 3, false, true);
TestSimpleRecurrence<double>(1000, 9, 16, 3, DeviceDescriptor::GPUDevice(0), true, 3, true, true, true);
#endif
TestSimpleRecurrence<float>(5000, 200, 19, 6, DeviceDescriptor::CPUDevice(), 2, false, true, true);
TestSimpleRecurrence<float>(5000, 200, 19, 6, DeviceDescriptor::CPUDevice(), true, 2, false, true, true);
#ifndef CPUONLY
TestRecurrentNetworkCreation<float>(DeviceDescriptor::GPUDevice(0));
TestRecurrentNetworkCreation<float>(DeviceDescriptor::GPUDevice(0), true);
#endif
TestRecurrentNetworkCreation<double>(DeviceDescriptor::CPUDevice());
TestRecurrentNetworkCreation<double>(DeviceDescriptor::CPUDevice(), false);
}

View file

@ -20,8 +20,8 @@ void TestTensorPlus(size_t numAxesLeftOperand, size_t numAxesRightOperand, const
for (size_t i = std::min(numAxesLeftOperand, numAxesRightOperand); i < numAxesRightOperand; ++i)
rightInputShape[i] = (rand() % maxDimSize) + 1;
Variable leftInputVar(leftInputShape, AsDataType<ElementType>(), L"leftInput");
Variable rightInputVar(rightInputShape, AsDataType<ElementType>(), L"rightInput");
Variable leftInputVar(leftInputShape, AsDataType<ElementType>(), true, L"leftInput");
Variable rightInputVar(rightInputShape, AsDataType<ElementType>(), true, L"rightInput");
auto plusFunc = Plus(leftInputVar, rightInputVar);