From abf468097b40a7ad8082f313ee5c82b8a18212e1 Mon Sep 17 00:00:00 2001
From: Amit Agarwal
Date: Sat, 3 Sep 2016 17:31:02 -0700
Subject: [PATCH] CNTK v2 library: Trainer checkpointing and some
 MinibatchSource creation helper APIs

---
 Source/CNTKv2LibraryDll/API/CNTKLibrary.h        | 187 ++++++++++++++----
 .../API/CNTKLibraryInternals.h                   |   5 +
 Source/CNTKv2LibraryDll/BackCompat.cpp           |  10 +-
 Source/CNTKv2LibraryDll/Common.cpp               |   9 +
 Source/CNTKv2LibraryDll/Function.cpp             |  39 +---
 Source/CNTKv2LibraryDll/Function.h               |   1 +
 Source/CNTKv2LibraryDll/Learner.cpp              |  23 +--
 Source/CNTKv2LibraryDll/NDArrayView.cpp          |   6 +-
 Source/CNTKv2LibraryDll/Trainer.cpp              |  86 +++++---
 Source/CNTKv2LibraryDll/Utils.h                  |  41 ++++
 Tests/UnitTests/V2LibraryTests/Common.h          |  45 +----
 Tests/UnitTests/V2LibraryTests/Seq2Seq.cpp       |  69 ++-----
 .../V2LibraryTests/SequenceClassification.cpp    |   2 +-
 .../UnitTests/V2LibraryTests/TrainerTests.cpp    |  15 +-
 14 files changed, 315 insertions(+), 223 deletions(-)

diff --git a/Source/CNTKv2LibraryDll/API/CNTKLibrary.h b/Source/CNTKv2LibraryDll/API/CNTKLibrary.h
index 5e03ffd18..91b4ff73e 100644
--- a/Source/CNTKv2LibraryDll/API/CNTKLibrary.h
+++ b/Source/CNTKv2LibraryDll/API/CNTKLibrary.h
@@ -501,10 +501,19 @@ namespace CNTK
         /// Fill 'this' NDArrayView with the specified value. The underlying DataType of 'this' view should be DataType::Double.
         ///
         CNTK_API void SetValue(double value);
+
+        ///
+        /// Creates a new NDArrayView with newly allocated storage on the specified device and copies 'this' view's contents into the newly allocated view.
+        ///
+        CNTK_API NDArrayViewPtr DeepClone(const DeviceDescriptor& device, bool readOnly = false) const;
+
         ///
         /// Creates a new NDArrayView with newly allocated storage on the same device as 'this' view and copies 'this' view's contents into the newly allocated view.
         ///
-        CNTK_API NDArrayViewPtr DeepClone(bool readOnly = false) const;
+        inline NDArrayViewPtr DeepClone(bool readOnly = false) const
+        {
+            return DeepClone(this->Device(), readOnly);
+        }

         ///
         /// Creates a new NDArrayView which is an alias of 'this' view; i.e. a new view of the same shape as 'this' over the same underlying data.
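A minimal usage sketch of the two DeepClone overloads declared above (the function and variable names here are illustrative, not part of the patch):

    // Sketch: copy a view's contents to the CPU for inspection; 'view' may
    // live on any device (e.g. a GPU).
    void CloneExamples(const CNTK::NDArrayViewPtr& view)
    {
        // New overload: allocate storage on an explicitly specified device and
        // copy the contents across devices if needed.
        auto cpuCopy = view->DeepClone(CNTK::DeviceDescriptor::CPUDevice(), /*readOnly =*/ true);

        // Old behavior, now an inline wrapper: clone on the same device as 'this' view.
        auto sameDeviceCopy = view->DeepClone(/*readOnly =*/ false);
    }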
@@ -854,6 +863,33 @@ namespace CNTK
         Placeholder
     };

+    inline const wchar_t* VariableKindName(VariableKind variableKind)
+    {
+        switch (variableKind)
+        {
+        case VariableKind::Input:
+            return L"Input";
+        case VariableKind::Output:
+            return L"Output";
+        case VariableKind::Parameter:
+            return L"Parameter";
+        case VariableKind::Constant:
+            return L"Constant";
+        case VariableKind::Placeholder:
+            return L"Placeholder";
+        default:
+            LogicError("Unknown VariableKind");
+        }
+    }
+
+    namespace Internal
+    {
+        inline std::wstring GenerateUid(VariableKind varKind)
+        {
+            return std::wstring(VariableKindName(varKind)) + std::to_wstring(Internal::NewUniqueId());
+        }
+    }
+
     // Forward declarations
     inline Variable PlaceholderVariable(const NDShape& shape, const std::vector<Axis>& dynamicAxes = Axis::DefaultInputVariableDynamicAxes);
     inline Variable InputVariable(const NDShape& shape, bool isSparse, CNTK::DataType dataType, bool needsGradient, const std::wstring& name = L"", const std::vector<Axis>& dynamicAxes = Axis::DefaultInputVariableDynamicAxes);
@@ -874,18 +910,18 @@ namespace CNTK
         template <typename T>
         friend struct std::hash;

+        template <typename ElementType>
+        friend Variable GetVariable(const Microsoft::MSR::CNTK::ComputationNodeBasePtr& node,
+                                    std::unordered_map<Microsoft::MSR::CNTK::ComputationNodeBasePtr, Variable>& nodeToVariableMap,
+                                    std::unordered_map<Variable, Variable>& placeholderReplacements,
+                                    std::unordered_set<FunctionPtr>& allPrimitiveFunctions);
+
     private:
         friend inline Variable PlaceholderVariable(const NDShape& shape, const std::vector<Axis>& dynamicAxes /*= Axis::DefaultInputVariableDynamicAxes*/);
         friend inline Variable InputVariable(const NDShape& shape, bool isSparse, CNTK::DataType dataType, bool needsGradient, const std::wstring& name /*= L""*/, const std::vector<Axis>& dynamicAxes /*= Axis::DefaultInputVariableDynamicAxes*/);
         friend inline Variable OutputVariable(const NDShape& shape, CNTK::DataType dataType, Function* ownerFunction, const std::vector<Axis>& dynamicAxes, const std::wstring& name /*= L""*/);

-    public:
-        ///
-        /// Create an 'Input' Variable
-        ///
-        Variable(const NDShape& shape, bool isSparse, CNTK::DataType dataType, bool needsGradient, const std::wstring& name, const std::vector<Axis>& dynamicAxes = Axis::DefaultInputVariableDynamicAxes)
-            : Variable(shape, VariableKind::Input, dataType, nullptr, nullptr, needsGradient, dynamicAxes, isSparse, name)
-        {}
+    public:

         ///
         /// Create an 'Output' variable aliasing the output of the specified Function
@@ -954,6 +990,11 @@ namespace CNTK
         ///
         const std::wstring& Name() const { return m_dataFields->m_name; }

+        ///
+        /// Returns the internally generated unique name of the variable
+        ///
+        const std::wstring& Uid() const { return m_dataFields->m_uid; }
+
         ///
         /// Returns the Function object which 'this' variable is an output of.
         /// Returns null when called for a Variable that is not of 'Output' VariableKind.
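Per the definitions above, a uid is simply the variable kind's name concatenated with a process-wide counter, so uids are unique within a process but not stable across runs. An illustrative sketch (the numeric suffixes shown are examples only):

    // What the new uid machinery produces:
    auto uid1 = CNTK::Internal::GenerateUid(CNTK::VariableKind::Parameter); // e.g. L"Parameter12"
    auto uid2 = CNTK::Internal::GenerateUid(CNTK::VariableKind::Input);     // e.g. L"Input13"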
@@ -971,8 +1012,8 @@ namespace CNTK
         bool NeedsGradient() const { return m_dataFields->m_needsGradient; }

     protected:
-        Variable(const NDShape& shape, VariableKind varType, CNTK::DataType dataType, const NDArrayViewPtr& value, bool needsGradient, const std::vector<Axis>& dynamicAxes, const std::wstring& name)
-            : Variable(shape, varType, dataType, nullptr, value, needsGradient, dynamicAxes, /*isSparse =*/ false, name)
+        Variable(const NDShape& shape, VariableKind varType, CNTK::DataType dataType, const NDArrayViewPtr& value, bool needsGradient, const std::vector<Axis>& dynamicAxes, const std::wstring& name, const std::wstring& uid)
+            : Variable(shape, varType, dataType, nullptr, value, needsGradient, dynamicAxes, /*isSparse =*/ false, name, uid)
         {}

         NDArrayViewPtr Value() const
@@ -982,8 +1023,13 @@ namespace CNTK
         }

     private:
-        Variable(const NDShape& shape, VariableKind varType, CNTK::DataType dataType, Function* ownerFunction, const NDArrayViewPtr& value, bool needsGradient, const std::vector<Axis>& dynamicAxes, bool isSparse, const std::wstring& name)
-            : m_dataFields(MakeSharedObject<VariableFields>(shape, varType, dataType, ownerFunction, value, needsGradient, dynamicAxes, isSparse, name))
+        Variable(const NDShape& shape, bool isSparse, CNTK::DataType dataType, bool needsGradient, const std::wstring& name, const std::vector<Axis>& dynamicAxes, const std::wstring& uid)
+            : Variable(shape, VariableKind::Input, dataType, nullptr, nullptr, needsGradient, dynamicAxes, isSparse, name, uid)
+        {}
+
+
+        Variable(const NDShape& shape, VariableKind varType, CNTK::DataType dataType, Function* ownerFunction, const NDArrayViewPtr& value, bool needsGradient, const std::vector<Axis>& dynamicAxes, bool isSparse, const std::wstring& name, const std::wstring& uid)
+            : m_dataFields(MakeSharedObject<VariableFields>(shape, varType, dataType, ownerFunction, value, needsGradient, dynamicAxes, isSparse, name, uid))
         {}

     private:
@@ -1001,9 +1047,10 @@ namespace CNTK
             std::wstring m_name;
             std::vector<Axis> m_dynamicAxes;
             bool m_isSparse;
+            std::wstring m_uid;

-            VariableFields(const NDShape& shape, VariableKind varType, CNTK::DataType type, Function* ownerFunction, const NDArrayViewPtr& value, bool needsGradient, const std::vector<Axis>& dynamicAxes, bool isSparse, const std::wstring& name)
-                : m_shape(shape), m_varKind(varType), m_dataType(type), m_ownerFunction(ownerFunction), m_value(value), m_needsGradient(needsGradient), m_dynamicAxes(dynamicAxes), m_isSparse(isSparse), m_name(name)
+            VariableFields(const NDShape& shape, VariableKind varType, CNTK::DataType type, Function* ownerFunction, const NDArrayViewPtr& value, bool needsGradient, const std::vector<Axis>& dynamicAxes, bool isSparse, const std::wstring& name, const std::wstring& uid)
+                : m_shape(shape), m_varKind(varType), m_dataType(type), m_ownerFunction(ownerFunction), m_value(value), m_needsGradient(needsGradient), m_dynamicAxes(dynamicAxes), m_isSparse(isSparse), m_name(name), m_uid(uid)
             {
                 if (value && (type != value->GetDataType()))
                     InvalidArgument("The DataType of the Parameter/Constant Variable does not match the DataType of the associated Value");
@@ -1043,7 +1090,16 @@ namespace CNTK
     ///
     inline Variable PlaceholderVariable(const NDShape& shape, const std::vector<Axis>& dynamicAxes /*= Axis::DefaultInputVariableDynamicAxes*/)
    {
-        return Variable(shape, VariableKind::Placeholder, DataType::Unknown, nullptr, false, dynamicAxes, L"");
+        auto varKind = VariableKind::Placeholder;
+        return Variable(shape, varKind, DataType::Unknown, nullptr, false, dynamicAxes, L"", Internal::GenerateUid(varKind));
+    }
+
+    ///
+    /// Create an 'Input' Variable denoting sparse data and specify if gradients are to be computed for this input
+    ///
+    inline Variable InputVariable(const NDShape& shape, bool isSparse, CNTK::DataType dataType, bool needsGradient, const std::wstring& name /*= L""*/, const std::vector<Axis>& dynamicAxes /*= Axis::DefaultInputVariableDynamicAxes*/)
+    {
+        return Variable(shape, isSparse, dataType, needsGradient, name, dynamicAxes, Internal::GenerateUid(VariableKind::Input));
     }

     ///
@@ -1051,7 +1107,7 @@ namespace CNTK
     ///
     inline Variable InputVariable(const NDShape& shape, CNTK::DataType dataType, bool needsGradient, const std::wstring& name = L"", const std::vector<Axis>& dynamicAxes = Axis::DefaultInputVariableDynamicAxes)
     {
-        return Variable(shape, /*isSparse =*/ false, dataType, needsGradient, name, dynamicAxes);
+        return InputVariable(shape, /*isSparse =*/ false, dataType, needsGradient, name, dynamicAxes);
     }

     ///
@@ -1078,14 +1134,6 @@ namespace CNTK
         return InputVariable(shape, dataType, L"", dynamicAxes);
     }

-    ///
-    /// Create an 'Input' Variable denoting sparse data and specify if gradients are to be computed for this input
-    ///
-    inline Variable InputVariable(const NDShape& shape, bool isSparse, CNTK::DataType dataType, bool needsGradient, const std::wstring& name /*= L""*/, const std::vector<Axis>& dynamicAxes /*= Axis::DefaultInputVariableDynamicAxes*/)
-    {
-        return Variable(shape, isSparse, dataType, needsGradient, name, dynamicAxes);
-    }
-
     ///
     /// Create an 'Input' Variable denoting sparse data.
     ///
@@ -1115,7 +1163,7 @@ namespace CNTK
     ///
     inline Variable OutputVariable(const NDShape& shape, CNTK::DataType dataType, Function* ownerFunction, const std::vector<Axis>& dynamicAxes, const std::wstring& name /*= L""*/)
     {
-        return Variable(shape, VariableKind::Output, dataType, ownerFunction, nullptr, /*needsGradient =*/ false, dynamicAxes, /*isSparse =*/ false, name);
+        return Variable(shape, VariableKind::Output, dataType, ownerFunction, nullptr, /*needsGradient =*/ false, dynamicAxes, /*isSparse =*/ false, name, Internal::GenerateUid(VariableKind::Output));
     }

     ///
@@ -1126,12 +1174,18 @@ namespace CNTK
         template <typename T>
         friend struct std::hash;

+        template <typename ElementType>
+        friend Variable GetVariable(const Microsoft::MSR::CNTK::ComputationNodeBasePtr& node,
+                                    std::unordered_map<Microsoft::MSR::CNTK::ComputationNodeBasePtr, Variable>& nodeToVariableMap,
+                                    std::unordered_map<Variable, Variable>& placeholderReplacements,
+                                    std::unordered_set<FunctionPtr>& allPrimitiveFunctions);
+
     public:
         ///
         /// Construct a parameter whose initial contents are a copy of the specified 'value'
         ///
         explicit Parameter(const NDArrayViewPtr& value, const std::wstring& name = L"")
-            : Variable(value->Shape(), VariableKind::Parameter, value->GetDataType(), value->DeepClone(), true, {}, name)
+            : Parameter(value, name, Internal::GenerateUid(VariableKind::Parameter))
         {}

         // TODO: Constructor to move a specified NDArrayView value

         ///
         /// Construct a parameter of specified shape whose contents are initialized with the specified 'initValue'
         ///
         template <typename ElemType>
         Parameter(const NDShape& shape, ElemType initValue, const DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice(), const std::wstring& name = L"")
-            : Variable(shape, VariableKind::Parameter, AsDataType<ElemType>(), MakeSharedObject<NDArrayView>(initValue, shape, device), true, {}, name)
+            : Variable(shape, VariableKind::Parameter, AsDataType<ElemType>(), MakeSharedObject<NDArrayView>(initValue, shape, device), true, {}, name, Internal::GenerateUid(VariableKind::Parameter))
         {}

         ///
         /// Construct a parameter of specified shape whose contents are initialized with the specified 'initValue'
         ///
         Parameter(const NDShape& shape, DataType dataType, double initValue, const DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice(), const std::wstring& name = L"")
-            : Variable(shape, VariableKind::Parameter, dataType, MakeSharedObject<NDArrayView>(initValue, dataType, shape, device), true, {}, name)
+            : Variable(shape, VariableKind::Parameter, dataType, MakeSharedObject<NDArrayView>(initValue, dataType, shape, device), true, {}, name, Internal::GenerateUid(VariableKind::Parameter))
         {}

         ///
@@ -1233,6 +1287,11 @@ namespace CNTK
             return Variable::Value();
         }

+    private:
+        explicit Parameter(const NDArrayViewPtr& value, const std::wstring& name, const std::wstring& uid)
+            : Variable(value->Shape(), VariableKind::Parameter, value->GetDataType(), value->DeepClone(), true, {}, name, uid)
+        {}
+
     private:
         // Helper methods for Parameter construction
@@ -1254,12 +1313,18 @@ namespace CNTK
         template <typename T>
         friend struct std::hash;

+        template <typename ElementType>
+        friend Variable GetVariable(const Microsoft::MSR::CNTK::ComputationNodeBasePtr& node,
+                                    std::unordered_map<Microsoft::MSR::CNTK::ComputationNodeBasePtr, Variable>& nodeToVariableMap,
+                                    std::unordered_map<Variable, Variable>& placeholderReplacements,
+                                    std::unordered_set<FunctionPtr>& allPrimitiveFunctions);
+
     public:
         ///
         /// Construct a Constant whose initial contents are a copy of the specified value
         ///
         Constant(const NDArrayViewPtr& value, const std::wstring& name = L"")
-            : Variable(value->Shape(), VariableKind::Constant, value->GetDataType(), value->DeepClone(true), false, {}, name)
+            : Constant(value, name, Internal::GenerateUid(VariableKind::Constant))
         {}

         // TODO: Constructor to move a specified NDArrayView value

         ///
         /// Construct a constant of specified shape whose contents are initialized with the specified 'initValue'
         ///
         template <typename ElemType>
         Constant(const NDShape& shape, ElemType initValue, const DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice(), const std::wstring& name = L"")
-            : Variable(shape, VariableKind::Constant, AsDataType<ElemType>(), MakeSharedObject<NDArrayView>(initValue, shape, device), false, {}, name)
+            : Variable(shape, VariableKind::Constant, AsDataType<ElemType>(), MakeSharedObject<NDArrayView>(initValue, shape, device), false, {}, name, Internal::GenerateUid(VariableKind::Constant))
         {}

         ///
         /// Construct a constant of specified shape whose contents are initialized with the specified 'initValue'
         ///
         Constant(const NDShape& shape, DataType dataType, double initValue, const DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice(), const std::wstring& name = L"")
-            : Variable(shape, VariableKind::Constant, dataType, MakeSharedObject<NDArrayView>(initValue, dataType, shape, device), false, {}, name)
+            : Variable(shape, VariableKind::Constant, dataType, MakeSharedObject<NDArrayView>(initValue, dataType, shape, device), false, {}, name, Internal::GenerateUid(VariableKind::Constant))
         {}

         ///
@@ -1313,6 +1378,11 @@ namespace CNTK
         {
             return Variable::Value();
         }
+
+    private:
+        Constant(const NDArrayViewPtr& value, const std::wstring& name, const std::wstring& uid)
+            : Variable(value->Shape(), VariableKind::Constant, value->GetDataType(), value->DeepClone(true), false, {}, name, uid)
+        {}
     };

     // Implementation note: The Variable type is a value type and not polymorphic in nature.
@@ -1799,8 +1869,8 @@ namespace CNTK
     ///
     inline FunctionPtr PastValue(const Variable& operand, size_t offset = 1, const std::wstring& name = L"")
     {
-        const Variable& initialState = Constant::Scalar(0.0f);
-        return PastValue(operand, initialState, offset, name);
+        static const auto defaultInitialState = Constant::Scalar(0.0f);
+        return PastValue(operand, defaultInitialState, offset, name);
     }

     ///
@@ -1816,8 +1886,8 @@ namespace CNTK
     ///
     inline FunctionPtr FutureValue(const Variable& operand, size_t offset = 1, const std::wstring& name = L"")
     {
-        const Variable& initialState = Constant::Scalar(0.0f);
-        return FutureValue(operand, initialState, offset, name);
+        static const auto defaultInitialState = Constant::Scalar(0.0f);
+        return FutureValue(operand, defaultInitialState, offset, name);
     }

     ///
@@ -2722,6 +2792,53 @@ namespace CNTK
     ///
     CNTK_API MinibatchSourcePtr CreateCompositeMinibatchSource(const Dictionary& configuration);

+    struct StreamConfiguration
+    {
+        StreamConfiguration(const std::wstring& streamName, size_t dim, bool isSparse = false, const std::wstring& streamAlias = L"")
+            : m_streamName(streamName), m_dim(dim), m_isSparse(isSparse), m_streamAlias(streamAlias)
+        {}
+
+        std::wstring m_streamName;
+        size_t m_dim;
+        bool m_isSparse;
+        std::wstring m_streamAlias;
+    };
+
+    ///
+    /// Instantiate the CNTK built-in text format minibatch source
+    ///
+    inline MinibatchSourcePtr TextFormatMinibatchSource(const std::wstring& dataFilePath, const std::vector<StreamConfiguration>& streamConfigs, size_t epochSize = SIZE_MAX)
+    {
+        CNTK::Dictionary minibatchSourceConfiguration;
+        minibatchSourceConfiguration[L"epochSize"] = epochSize;
+
+        CNTK::Dictionary deserializerConfiguration;
+        deserializerConfiguration[L"type"] = L"CNTKTextFormatDeserializer";
+        deserializerConfiguration[L"file"] = dataFilePath;
+
+        CNTK::Dictionary inputStreamsConfig;
+        for (auto streamConfig : streamConfigs)
+        {
+            std::wstring streamName = streamConfig.m_streamName;
+            size_t streamDim = streamConfig.m_dim;
+            bool isSparse = streamConfig.m_isSparse;
+            std::wstring streamAlias = streamConfig.m_streamAlias;
+
+            CNTK::Dictionary inputStreamConfig;
+            inputStreamConfig[L"dim"] = streamDim;
+            inputStreamConfig[L"format"] = isSparse ? L"sparse" : L"dense";
+            if (!streamAlias.empty())
+                inputStreamConfig[L"alias"] = streamAlias;
+
+            inputStreamsConfig[streamName] = inputStreamConfig;
+        }
+
+        deserializerConfiguration[L"input"] = inputStreamsConfig;
+        minibatchSourceConfiguration[L"deserializers"] = std::vector<CNTK::DictionaryValue>({ deserializerConfiguration });
+
+        return CreateCompositeMinibatchSource(minibatchSourceConfiguration);
+    }
+
     ///
     /// Compute the per dimension means and variances for each of the specified streams using data from the specified minibatchSource.
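The helper above folds the dictionary-based reader configuration into a single call. A usage sketch (the file name, stream names, dimensions, and alias are illustrative; they mirror the test changes later in this patch):

    // One dense feature stream and one sparse label stream read from a CTF file.
    auto source = CNTK::TextFormatMinibatchSource(
        L"Train.ctf",
        { CNTK::StreamConfiguration(L"features", 784),
          CNTK::StreamConfiguration(L"labels", 10, /*isSparse =*/ true, /*streamAlias =*/ L"y") },
        /*epochSize =*/ SIZE_MAX);

    auto featureStreamInfo = source->StreamInfo(L"features");
    auto labelStreamInfo = source->StreamInfo(L"labels");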
     ///
diff --git a/Source/CNTKv2LibraryDll/API/CNTKLibraryInternals.h b/Source/CNTKv2LibraryDll/API/CNTKLibraryInternals.h
index 039ce28b3..9feadb03d 100644
--- a/Source/CNTKv2LibraryDll/API/CNTKLibraryInternals.h
+++ b/Source/CNTKv2LibraryDll/API/CNTKLibraryInternals.h
@@ -53,6 +53,9 @@ namespace Microsoft { namespace MSR { namespace CNTK {
     template <typename ElemType>
     class ComputationNode;
+
+    class ComputationNodeBase;
+    typedef std::shared_ptr<ComputationNodeBase> ComputationNodeBasePtr;
 }}}

 // TODO: The following should be reconciled with the equivalent code in the CNTK implementation
@@ -195,5 +198,7 @@ namespace CNTK
         CNTK_API FunctionPtr Scatter(const Variable& operand, const Variable& condition, const std::vector<Axis>& newDynamicAxes, const std::wstring& name = L"");
         CNTK_API FunctionPtr Slice(const Variable& operand, const Axis& axis, int beginIndex, int endIndex, const std::wstring& name = L"");
         CNTK_API FunctionPtr ReduceElements(const Variable& operand, const std::wstring& reductionOpName, const Axis& axis, const std::wstring& name = L"");
+
+        CNTK_API size_t NewUniqueId();
     }
 }
diff --git a/Source/CNTKv2LibraryDll/BackCompat.cpp b/Source/CNTKv2LibraryDll/BackCompat.cpp
index 8ba43d6e1..2c77f6db7 100644
--- a/Source/CNTKv2LibraryDll/BackCompat.cpp
+++ b/Source/CNTKv2LibraryDll/BackCompat.cpp
@@ -45,7 +45,7 @@ namespace CNTK
                     auto inputNodeInternalDynamicAxisName = node->GetMBLayout()->GetAxisName();
                     std::vector<Axis> inputVarDynamicAxes = DynamicAxesFromInternalDynamicAxisName(inputNodeInternalDynamicAxisName);

-                    var = Variable(varShape, isSparse, AsDataType<ElementType>(), node->GetLearningRateMultiplier() != 0, node->GetName(), inputVarDynamicAxes);
+                    var = Variable(varShape, isSparse, AsDataType<ElementType>(), node->GetLearningRateMultiplier() != 0, node->NodeName(), inputVarDynamicAxes, node->NodeName());
                 }
                 else
                 {
@@ -58,11 +58,11 @@ namespace CNTK
                     bool isConstant = (node->GetLearningRateMultiplier() == 0);
                     auto& matrix = node->As<ComputationNode<ElementType>>()->Value();
                     auto tensorView = new TensorView<ElementType>(std::make_shared<Matrix<ElementType>>(matrix.AsReference()), AsTensorViewShape(node->GetSampleLayout()));
-                    NDArrayViewPtr parameterValue = MakeSharedObject<NDArrayView>(AsDataType<ElementType>(), AsDeviceDescriptor(matrix.GetDeviceId()), AsStorageFormat(matrix.GetFormat()), varShape, false, tensorView);
+                    NDArrayViewPtr value = MakeSharedObject<NDArrayView>(AsDataType<ElementType>(), AsDeviceDescriptor(matrix.GetDeviceId()), AsStorageFormat(matrix.GetFormat()), varShape, false, tensorView);
                     if (isConstant)
-                        var = Constant(parameterValue, node->GetName());
+                        var = Constant(value, node->NodeName(), node->NodeName());
                     else
-                        var = Parameter(parameterValue, node->GetName());
+                        var = Parameter(value, node->NodeName(), node->NodeName());
                 }
                 else
                     LogicError("CNTK::LoadLegacyModel: Unsupported legacy CNTK node named '%S'", node->NodeName().c_str());
@@ -276,7 +276,7 @@ namespace CNTK
             // Let's reorder inputVars properly since the ordering of inputs of CNTK internal ComputationNode may be different from the PrimitiveFunction inputs ordering
             ReorderAsPrimitiveFunctionInputs(opType, inputVars);

-            FunctionPtr primitiveFunction = MakeSharedObject<PrimitiveFunction>(opType, inputVars, std::move(primitiveFunctionConfigParameters), node->GetName());
+            FunctionPtr primitiveFunction = MakeSharedObject<PrimitiveFunction>(opType, inputVars, std::move(primitiveFunctionConfigParameters), node->NodeName());
             allPrimitiveFunctions.insert(primitiveFunction);
             var = primitiveFunction->Output();
             if (placeholderReplacements.find(placeholderVar) != placeholderReplacements.end())
diff --git a/Source/CNTKv2LibraryDll/Common.cpp b/Source/CNTKv2LibraryDll/Common.cpp
index 5c17c301f..bd7669670 100644
--- a/Source/CNTKv2LibraryDll/Common.cpp
+++ b/Source/CNTKv2LibraryDll/Common.cpp
@@ -8,6 +8,15 @@

 namespace CNTK
 {
+    namespace Internal
+    {
+        size_t NewUniqueId()
+        {
+            static std::atomic<size_t> s_nextUniqueId = 0;
+            return s_nextUniqueId++;
+        }
+    }
+
     /*static*/ std::atomic<bool> DeviceDescriptor::s_defaultDeviceFrozen(false);
     /*static*/ std::shared_ptr<DeviceDescriptor> DeviceDescriptor::s_defaultDevice(new DeviceDescriptor(DeviceDescriptor::GPUDevice(0)));
diff --git a/Source/CNTKv2LibraryDll/Function.cpp b/Source/CNTKv2LibraryDll/Function.cpp
index 1504e6c47..8655c7cd1 100644
--- a/Source/CNTKv2LibraryDll/Function.cpp
+++ b/Source/CNTKv2LibraryDll/Function.cpp
@@ -451,15 +451,11 @@ namespace CNTK
             std::shared_ptr<ComputationNode<ElementType>> computationNodePtr;
             if (variable.IsParameter() || variable.IsConstant())
             {
-                computationNodePtr = builder.CreateLearnableParameter(variable.Name(), AsTensorShape(variable.Shape()));
+                computationNodePtr = builder.CreateLearnableParameter(variable.Uid(), AsTensorShape(variable.Shape()));
                 network->InitLearnableParameters(computationNodePtr, L"fixedValue", 0); // must call this to follow protocol; can overwrite later
                 if (!variable.NeedsGradient())
                     computationNodePtr->SetLearningRateMultiplier(0.0);

-                // If the parameter variable does not have a name assign it the internal computation node name
-                if (variable.Name().empty())
-                    variable.m_dataFields->m_name = computationNodePtr->NodeName();
-
                 NDArrayViewPtr value = variable.IsConstant() ? Constant(variable).Value() : Parameter(variable).Value();
                 std::shared_ptr<const Matrix<ElementType>> valueMatrix = variable.IsConstant() ? value->GetMatrix<ElementType>() : value->GetWritableMatrix<ElementType>();
                 if (variable.IsParameter() || (valueMatrix->GetDeviceId() == network->GetDeviceId()))
@@ -493,9 +489,9 @@ namespace CNTK
                     network->AddNodeToNetAndAttachInputs(New<DynamicAxisNode<ElementType>>(network->GetDeviceId(), internalDynamicAxisName), {});

                 if (IsSparseInput(variable))
-                    computationNodePtr = builder.CreateSparseInputNode(variable.Name(), AsTensorShape(variable.Shape()), internalDynamicAxisName);
+                    computationNodePtr = builder.CreateSparseInputNode(variable.Uid(), AsTensorShape(variable.Shape()), internalDynamicAxisName);
                 else
-                    computationNodePtr = builder.CreateInputNode(variable.Name(), AsTensorShape(variable.Shape()), internalDynamicAxisName);
+                    computationNodePtr = builder.CreateInputNode(variable.Uid(), AsTensorShape(variable.Shape()), internalDynamicAxisName);

                 if (variable.NeedsGradient())
                 {
@@ -796,36 +792,11 @@ namespace CNTK
                 // If the inputVar is a constant and not the right DataType lets cast it to the right type
                 if (inputVar.IsConstant() && (nonConstInputDataType != DataType::Unknown) && (inputVar.GetDataType() != nonConstInputDataType))
                 {
-                    auto constantValue = Constant(inputVar).Value();
-                    NDArrayView constantValueCPU(constantValue->GetDataType(), constantValue->Shape(), DeviceDescriptor::CPUDevice());
-                    constantValueCPU.CopyFrom(*constantValue);
-
-                    NDArrayViewPtr newConstantValue;
-                    if (inputVar.GetDataType() == DataType::Float)
-                    {
-                        // Cast to double
-                        const float* buffer = constantValueCPU.DataBuffer<float>();
-                        double* castValue = new double[constantValueCPU.Shape().TotalSize()];
-                        for (size_t i = 0; i < constantValueCPU.Shape().TotalSize(); ++i)
-                            castValue[i] = buffer[i];
-
-                        newConstantValue = MakeSharedObject<NDArrayView>(constantValue->Shape(), castValue, constantValueCPU.Shape().TotalSize(), DeviceDescriptor::CPUDevice());
-                    }
-                    else
-                    {
-                        // Cast to float
-                        const double* buffer = constantValueCPU.DataBuffer<double>();
-                        float* castValue = new float[constantValueCPU.Shape().TotalSize()];
-                        for (size_t i = 0; i < constantValueCPU.Shape().TotalSize(); ++i)
-                            castValue[i] = (float)(buffer[i]);
-
-                        newConstantValue = MakeSharedObject<NDArrayView>(constantValue->Shape(), castValue, constantValueCPU.Shape().TotalSize(), DeviceDescriptor::CPUDevice());
-                    }
-
+                    auto constantValueCPU = Constant(inputVar).Value()->DeepClone(DeviceDescriptor::CPUDevice(), true);
+                    NDArrayViewPtr newConstantValue = CloneAsDataType(constantValueCPU, nonConstInputDataType, true);
                     inputVar = Constant(newConstantValue);
                 }
-
                 auto baseNodePtr = GetNode(inputVar, network, builder, variableToNodeMap, isVariableRootMap);
                 inputNodes.push_back((baseNodePtr != nullptr) ? baseNodePtr->template As<ComputationNode<ElementType>>()->shared_from_this() : nullptr);
             }
diff --git a/Source/CNTKv2LibraryDll/Function.h b/Source/CNTKv2LibraryDll/Function.h
index 513dd5b44..fa305db58 100644
--- a/Source/CNTKv2LibraryDll/Function.h
+++ b/Source/CNTKv2LibraryDll/Function.h
@@ -462,6 +462,7 @@ namespace CNTK
     class CompositeFunction final : public Function
     {
         friend class Function;
+        friend class Trainer;
         friend class CompositeMinibatchSource;

         template <typename T, typename ...CtorArgTypes>
diff --git a/Source/CNTKv2LibraryDll/Learner.cpp b/Source/CNTKv2LibraryDll/Learner.cpp
index 515a58c96..8995549c8 100644
--- a/Source/CNTKv2LibraryDll/Learner.cpp
+++ b/Source/CNTKv2LibraryDll/Learner.cpp
@@ -216,12 +216,12 @@ namespace CNTK
             const auto& gradientValue = gradientValues.at(parameter);
             // TODO: make this a runtime parameter.
 #if DUMPOUTPUT
-            LOGPRINTF(stderr, "Update_%ls\n", parameter.Name().c_str());
+            LOGPRINTF(stderr, "Update_%ls\n", parameter.Uid().c_str());
 #endif

 #ifdef _DEBUG
             if (HasNan(smoothedGradientValue, "TrainOneEpoch/UpdateWeights/Learner::Update(): "))
-                LogicError("%ls has NaNs in smoothedGradient.", parameter.Name().c_str());
+                LogicError("%ls has NaNs in smoothedGradient.", parameter.Uid().c_str());
 #endif

 #if DUMPOUTPUT
@@ -243,7 +243,7 @@ namespace CNTK
 #ifdef _DEBUG
             const auto& parameterValue = parameter.Value();
             if (HasNan(parameterValue, "TrainOneEpoch/UpdateWeights/Learner::Update(): "))
-                LogicError("%ls has NaNs in parameter values after parameter update.", parameter.Name().c_str());
+                LogicError("%ls has NaNs in parameter values after parameter update.", parameter.Uid().c_str());
 #endif
         }
         m_sampleCount += trainingSampleCount;
@@ -286,16 +286,13 @@ namespace CNTK

         for (const auto& parameter : Parameters())
         {
-            // TODO: parameter name is not guaranteed to be unique. Instead, all serializable objects
-            // need to expose "UId" property -- a persistent unique internal name.
-            // Switch to UId as soon as it's available.
-            if (checkpoint.Contains(parameter.Name()))
+            if (checkpoint.Contains(parameter.Uid()))
             {
                 LogicError("Parameter names must be unique");
             }

             const auto& smoothedGradientValue = m_smoothedGradientValues.at(parameter);
-            checkpoint[parameter.Name()] = *smoothedGradientValue;
+            checkpoint[parameter.Uid()] = *smoothedGradientValue;
         }
         return checkpoint;
     }
@@ -314,24 +311,24 @@ namespace CNTK

         for (const auto& parameter : Parameters())
         {
-            if (!checkpoint.Contains(parameter.Name()))
+            if (!checkpoint.Contains(parameter.Uid()))
             {
-                LogicError("Checkpoint does not contain state for parameter %ls", parameter.Name().c_str());
+                LogicError("Checkpoint does not contain state for parameter %ls", parameter.Uid().c_str());
             }

             const auto& smoothedGradientValue = m_smoothedGradientValues.at(parameter);
-            const NDArrayView& checkpointedValue = checkpoint[parameter.Name()].Value<NDArrayView>();
+            const NDArrayView& checkpointedValue = checkpoint[parameter.Uid()].Value<NDArrayView>();

             if (smoothedGradientValue->GetDataType() != checkpointedValue.GetDataType())
             {
                 LogicError("A value restored from a checkpoint for the smoothed gradient data type for parameter %ls does not match the expected value",
-                           parameter.Name().c_str());
+                           parameter.Uid().c_str());
             }

             if (smoothedGradientValue->Shape() != checkpointedValue.Shape())
             {
                 LogicError("A value restored from a checkpoint for the smoothed gradient shape for parameter %ls does not match the expected value",
-                           parameter.Name().c_str());
+                           parameter.Uid().c_str());
             }

             smoothedGradientValue->CopyFrom(checkpointedValue);
diff --git a/Source/CNTKv2LibraryDll/NDArrayView.cpp b/Source/CNTKv2LibraryDll/NDArrayView.cpp
index 4dc973ebc..5d7c61398 100644
--- a/Source/CNTKv2LibraryDll/NDArrayView.cpp
+++ b/Source/CNTKv2LibraryDll/NDArrayView.cpp
@@ -212,9 +212,9 @@ namespace CNTK
         return const_cast<TensorView<ElementType>*>(GetTensorView<ElementType>());
     }

-    NDArrayViewPtr NDArrayView::DeepClone(bool readOnly/* = false*/) const
+    NDArrayViewPtr NDArrayView::DeepClone(const DeviceDescriptor& device, bool readOnly/* = false*/) const
     {
-        NDArrayViewPtr newView = MakeSharedObject<NDArrayView>(this->GetDataType(), this->GetStorageFormat(), this->Shape(), this->Device());
+        NDArrayViewPtr newView = MakeSharedObject<NDArrayView>(this->GetDataType(), this->GetStorageFormat(), this->Shape(), device);
         switch (m_dataType)
         {
         case DataType::Float:
@@ -242,7 +242,7 @@ namespace CNTK

     void NDArrayView::CopyFrom(const NDArrayView& source)
     {
-        if (source.Shape() != Shape())
+        if ((source.Shape() != Shape()) && (AsTensorShape(source.Shape()) != AsTensorShape(Shape())))
             InvalidArgument("NDArrayView::CopyFrom: The 'source' view's shape must be same as the shape of this NDArrayView");

         if (IsReadOnly())
diff --git a/Source/CNTKv2LibraryDll/Trainer.cpp b/Source/CNTKv2LibraryDll/Trainer.cpp
index 0a501aa87..f37697688 100644
--- a/Source/CNTKv2LibraryDll/Trainer.cpp
+++ b/Source/CNTKv2LibraryDll/Trainer.cpp
@@ -6,6 +6,7 @@
 #include "stdafx.h"
 #include "CNTKLibrary.h"
 #include "Utils.h"
+#include "Function.h"

 namespace CNTK
 {
@@ -160,8 +161,6 @@ namespace CNTK

     void Trainer::SaveCheckpoint(const std::wstring& modelFilePath)
     {
-        LogicError("Trainer checkpointing is currently not supported");
-
         SaveAsLegacyModel(m_combinedTrainingFunction, modelFilePath);

         if (m_parameterLearners.size() > 1)
@@ -176,38 +175,73 @@ namespace CNTK

     void Trainer::RestoreFromCheckpoint(const std::wstring& modelFilePath)
     {
-        LogicError("Trainer checkpointing is currently not supported");
-
-        auto firstLearner = *(m_parameterLearners.begin());
-        auto device = firstLearner->Parameters().begin()->Value()->Device();
-
-        // Determine the indices of the model, loss and evaluation functions in the combined function's outputs to properly restore them after loading the model
-        auto findFunctionIdx = [](const FunctionPtr& combinedFunction, const FunctionPtr& functionToFind) {
-            if (functionToFind->Outputs().size() != 1)
-                LogicError("The trainer's model, loss or evaluation functions should have onlye 1 output");
+        auto loadedModelFunction = LoadLegacyModel(m_combinedTrainingFunction->Outputs()[0].GetDataType(), modelFilePath, DeviceDescriptor::CPUDevice());

-            auto combinedOutputs = combinedFunction->Outputs();
-            auto functionToFindOutput = functionToFind->Output();
-            for (size_t i = 0; i < combinedOutputs.size(); ++i)
+        // TODO: Make sure that the loaded model is the same as the trainer's model through UID matching in the V2 format
+        // TODO: For V1 format models make sure that the loaded model is isomorphic to the trainer's model
+        auto loadedModelLeafVariables = loadedModelFunction->Inputs();
+        auto trainerModelLeafVariables = m_combinedTrainingFunction->Inputs();
+        if (trainerModelLeafVariables.size() != loadedModelLeafVariables.size())
+            InvalidArgument("The loaded model's leaf variables do not match the trainer model's leaf variables");
+
+        std::map<std::wstring, Variable> loadedModelLeafVariablesMap;
+        for (auto leafVar : loadedModelLeafVariables)
+            loadedModelLeafVariablesMap[leafVar.Uid()] = leafVar;
+
+        std::map<std::wstring, Variable> trainerModelLeafVariablesMap;
+        for (auto leafVar : trainerModelLeafVariables)
+            trainerModelLeafVariablesMap[leafVar.Uid()] = leafVar;
+
+        // Remove the initial state inputs of PastValue and FutureValue functions from the maps if they are a scalar constant
+        // since these are not part of the internal CNTK serialized computation graph
+        auto removePastAndFutureValueInitialStateScalarConstants = [](const std::unordered_set<FunctionPtr>& allPrimitiveFunctions, std::map<std::wstring, Variable>& modelLeafVariableMap) {
+            for (auto funcPtr : allPrimitiveFunctions)
             {
-                if (combinedOutputs[i] == functionToFindOutput)
-                    return i;
+                auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(funcPtr.get());
+                if ((primitiveFunction->OpType() == PrimitiveOpType::PastValue) || (primitiveFunction->OpType() == PrimitiveOpType::FutureValue))
+                {
+                    auto initialStateInput = primitiveFunction->Inputs()[1];
+                    if (initialStateInput.IsConstant() && (initialStateInput.Shape().TotalSize() == 1))
+                        modelLeafVariableMap.erase(initialStateInput.Uid());
+                }
             }
-
-            LogicError("Specified model/loss/evaluation function not found within the trainer's combined root function");
         };

-        size_t modelFunctionIdx = findFunctionIdx(m_combinedTrainingFunction, m_model);
-        size_t lossFunctionIndex = findFunctionIdx(m_combinedTrainingFunction, m_lossFunction);
-        size_t evaluationFunctionIdx = SIZE_MAX;
-        if (m_evaluationFunction)
-            evaluationFunctionIdx = findFunctionIdx(m_combinedTrainingFunction, m_evaluationFunction);
+        auto loadedModelCompositeFunction = dynamic_cast<CompositeFunction*>(loadedModelFunction.get());
+        removePastAndFutureValueInitialStateScalarConstants(loadedModelCompositeFunction->m_allPrimitiveFunctions, loadedModelLeafVariablesMap);

-        m_combinedTrainingFunction = LoadLegacyModel(m_combinedTrainingFunction->Outputs()[0].GetDataType(), modelFilePath, device);
-        m_model = Combine({ m_combinedTrainingFunction->Outputs()[modelFunctionIdx].Owner() });
-        m_lossFunction = Combine({ m_combinedTrainingFunction->Outputs()[lossFunctionIndex].Owner() });
-        if (m_evaluationFunction)
-            m_evaluationFunction = Combine({ m_combinedTrainingFunction->Outputs()[evaluationFunctionIdx].Owner() });
+        auto trainerModelCompositeFunction = dynamic_cast<CompositeFunction*>(m_combinedTrainingFunction.get());
+        removePastAndFutureValueInitialStateScalarConstants(trainerModelCompositeFunction->m_allPrimitiveFunctions, trainerModelLeafVariablesMap);
+
+        // Now update the trainer's model parameters and constants with those from the loaded model
+        for (auto nameVarPair : trainerModelLeafVariablesMap)
+        {
+            auto trainerModelLeafVar = nameVarPair.second;
+
+            auto areVariablesEquivalent = [](const Variable& left, const Variable& right) {
+                return ((left.Kind() == right.Kind()) &&
+                        ((left.Shape() == right.Shape()) || (AsTensorShape(left.Shape()) == AsTensorShape(right.Shape()))) &&
+                        (left.GetDataType() == right.GetDataType()) &&
+                        (left.DynamicAxes().size() == right.DynamicAxes().size()) &&
+                        (left.NeedsGradient() == right.NeedsGradient()) &&
+                        (left.Uid() == right.Uid()) &&
+                        (left.IsSparse() == right.IsSparse()));
+            };
+
+            auto correspondingLoadedModelVar = loadedModelLeafVariablesMap.at(trainerModelLeafVar.Uid());
+
+            if (!areVariablesEquivalent(correspondingLoadedModelVar, trainerModelLeafVar))
+                InvalidArgument("The loaded model's leaf variables do not match the trainer model's leaf variables");
+
+            if (trainerModelLeafVar.IsConstant() || trainerModelLeafVar.IsParameter())
+            {
+                auto trainerModelVarValue = trainerModelLeafVar.IsConstant() ? Constant(trainerModelLeafVar).Value() : Parameter(trainerModelLeafVar).Value();
+                auto loadedModelVarValue = correspondingLoadedModelVar.IsConstant() ? Constant(correspondingLoadedModelVar).Value() : Parameter(correspondingLoadedModelVar).Value();
+                trainerModelVarValue->CopyFrom(*loadedModelVarValue);
+            }
+        }

         if (m_parameterLearners.size() > 1)
             LogicError("Trainer::RestoreFromCheckpoint: Checkpointing is currently unsupported for multiple learners");
diff --git a/Source/CNTKv2LibraryDll/Utils.h b/Source/CNTKv2LibraryDll/Utils.h
index f1af79d7a..ab71d162c 100644
--- a/Source/CNTKv2LibraryDll/Utils.h
+++ b/Source/CNTKv2LibraryDll/Utils.h
@@ -338,4 +338,45 @@ namespace CNTK
     {
         return std::pow(momentumPerSample, minibatchSize);
     }
+
+    template <typename TargetElementType, typename SourceElementType>
+    inline TargetElementType* Copy(const SourceElementType* src, size_t srcSize)
+    {
+        // Allocate a new buffer and copy the source contents with an element-wise cast
+        TargetElementType* castValue = new TargetElementType[srcSize];
+        for (size_t i = 0; i < srcSize; ++i)
+            castValue[i] = (TargetElementType)src[i];
+
+        return castValue;
+    }
+
+    inline NDArrayViewPtr CloneAsDataType(const NDArrayViewPtr& source, DataType targetDataType, bool readOnly)
+    {
+        if (source->Device() != DeviceDescriptor::CPUDevice())
+            LogicError("CloneAsDataType currently does not support non-CPU source NDArrayView objects");
+
+        auto sourceDataType = source->GetDataType();
+        if (sourceDataType == targetDataType)
+            LogicError("CloneAsDataType: Source and target DataTypes are same");
+
+        if ((targetDataType != DataType::Float) && (targetDataType != DataType::Double))
+            LogicError("CloneAsDataType: Only Float and Double target DataTypes are supported");
+
+        NDArrayViewPtr newConstantValue;
+        auto sourceShape = source->Shape();
+        auto sourceSize = sourceShape.TotalSize();
+        if (sourceDataType == DataType::Float)
+        {
+            // Cast to double
+            double* castValue = Copy<double>(source->DataBuffer<float>(), sourceSize);
+            newConstantValue = MakeSharedObject<NDArrayView>(sourceShape, castValue, sourceSize, DeviceDescriptor::CPUDevice(), readOnly);
+        }
+        else
+        {
+            // Cast to float
+            float* castValue = Copy<float>(source->DataBuffer<double>(), sourceSize);
+            newConstantValue = MakeSharedObject<NDArrayView>(sourceShape, castValue, sourceSize, DeviceDescriptor::CPUDevice(), readOnly);
+        }
+
+        return newConstantValue;
+    }
 }
diff --git a/Tests/UnitTests/V2LibraryTests/Common.h b/Tests/UnitTests/V2LibraryTests/Common.h
index 7713e3c46..402dc0c49 100644
--- a/Tests/UnitTests/V2LibraryTests/Common.h
+++ b/Tests/UnitTests/V2LibraryTests/Common.h
@@ -59,12 +59,12 @@ inline void SaveAndReloadModel(CNTK::FunctionPtr& functionPtr, const std::vector<CNTK::Variable*>
     if ((_wunlink(s_tempModelPath.c_str()) != 0) && (errno != ENOENT))
         std::runtime_error("Error deleting temp model file 'feedForward.net'");

-    std::unordered_map<std::wstring, CNTK::Variable*> inputVarNames;
+    std::unordered_map<std::wstring, CNTK::Variable*> inputVarUids;
     std::unordered_map<std::wstring, CNTK::Variable*> outputVarNames;

     for (auto varPtr : variables)
     {
-        auto retVal = varPtr->IsOutput() ? outputVarNames.insert({ varPtr->Owner()->Name(), varPtr }) : inputVarNames.insert({ varPtr->Name(), varPtr });
+        auto retVal = varPtr->IsOutput() ? outputVarNames.insert({ varPtr->Owner()->Name(), varPtr }) : inputVarUids.insert({ varPtr->Uid(), varPtr });
         if (!retVal.second)
             std::runtime_error("SaveAndReloadModel: Multiple variables having same name cannot be restored after save and reload");
     }
@@ -76,10 +76,10 @@ inline void SaveAndReloadModel(CNTK::FunctionPtr& functionPtr, const std::vector<CNTK::Variable*>
         std::runtime_error("Error deleting temp model file 'feedForward.net'");

     auto inputs = functionPtr->Inputs();
-    for (auto inputVarInfo : inputVarNames)
+    for (auto inputVarInfo : inputVarUids)
     {
         auto newInputVar = *(std::find_if(inputs.begin(), inputs.end(), [inputVarInfo](const CNTK::Variable& var) {
-            return (var.Name() == inputVarInfo.first);
+            return (var.Uid() == inputVarInfo.first);
         }));

         *(inputVarInfo.second) = newInputVar;
@@ -196,43 +196,6 @@ std::pair<CNTK::FunctionPtr, CNTK::FunctionPtr> LSTMPComponentWithSelfStabilizat
     return { LSTMCell.first, LSTMCell.second };
 }

-inline CNTK::MinibatchSourcePtr CreateTextMinibatchSource(const std::wstring& filePath,
-                                                          size_t featureDim,
-                                                          size_t labelDim,
-                                                          size_t epochSize,
-                                                          bool isFeatureSparse = false,
-                                                          bool isLabelSparse = false,
-                                                          const std::wstring& featureAlias = L"",
-                                                          const std::wstring& labelAlias = L"")
-{
-    CNTK::Dictionary featuresStreamConfig;
-    featuresStreamConfig[L"dim"] = featureDim;
-    featuresStreamConfig[L"format"] = isFeatureSparse ? L"sparse" : L"dense";
-    if (!featureAlias.empty())
-        featuresStreamConfig[L"alias"] = featureAlias;
-
-    CNTK::Dictionary labelsStreamConfig;
-    labelsStreamConfig[L"dim"] = labelDim;
-    labelsStreamConfig[L"format"] = isLabelSparse ? L"sparse" : L"dense";
-    if (!labelAlias.empty())
-        labelsStreamConfig[L"alias"] = labelAlias;
-
-    CNTK::Dictionary inputStreamsConfig;
-    inputStreamsConfig[L"features"] = featuresStreamConfig;
-    inputStreamsConfig[L"labels"] = labelsStreamConfig;
-
-    CNTK::Dictionary deserializerConfiguration;
-    deserializerConfiguration[L"type"] = L"CNTKTextFormatDeserializer";
-    deserializerConfiguration[L"file"] = filePath;
-    deserializerConfiguration[L"input"] = inputStreamsConfig;
-
-    CNTK::Dictionary minibatchSourceConfiguration;
-    minibatchSourceConfiguration[L"epochSize"] = epochSize;
-    minibatchSourceConfiguration[L"deserializers"] = std::vector<CNTK::DictionaryValue>({ deserializerConfiguration });
-
-    return CreateCompositeMinibatchSource(minibatchSourceConfiguration);
-}
-
 inline std::vector<size_t> GenerateSequenceLengths(size_t numSequences, size_t maxAllowedSequenceLength)
 {
     std::vector<size_t> sequenceLengths(numSequences);
diff --git a/Tests/UnitTests/V2LibraryTests/Seq2Seq.cpp b/Tests/UnitTests/V2LibraryTests/Seq2Seq.cpp
index 37223a9ab..4807e01df 100644
--- a/Tests/UnitTests/V2LibraryTests/Seq2Seq.cpp
+++ b/Tests/UnitTests/V2LibraryTests/Seq2Seq.cpp
@@ -6,38 +6,6 @@
 using namespace CNTK;
 using namespace std::placeholders;

-inline CNTK::MinibatchSourcePtr CreateSeq2SeqMinibatchSource(const std::wstring& filePath, size_t inputVocabSize, size_t labelsVocabSize)
-{
-    CNTK::Dictionary inputStreamConfig;
-    inputStreamConfig[L"dim"] = inputVocabSize;
-    inputStreamConfig[L"format"] = L"sparse";
-    inputStreamConfig[L"alias"] = L"S0";
-
-    CNTK::Dictionary labelsStreamConfig;
-    labelsStreamConfig[L"dim"] = labelsVocabSize;
-    labelsStreamConfig[L"format"] = L"sparse";
-    labelsStreamConfig[L"alias"] = L"S1";
-
-    CNTK::Dictionary inputStreamsConfig;
-    inputStreamsConfig[L"rawInput"] = inputStreamConfig;
-    inputStreamsConfig[L"rawLabels"] = labelsStreamConfig;
-
-    CNTK::Dictionary deserializerConfiguration;
-    deserializerConfiguration[L"type"] = L"CNTKTextFormatDeserializer";
-    deserializerConfiguration[L"file"] = filePath;
-    deserializerConfiguration[L"input"] = inputStreamsConfig;
-    deserializerConfiguration[L"skipSequenceIds"] = L"false";
-    deserializerConfiguration[L"maxErrors"] = (size_t)100;
-    deserializerConfiguration[L"traceLevel"] = (size_t)1;
-    deserializerConfiguration[L"chunkSizeInBytes"] = (size_t)30000000;
-
-    CNTK::Dictionary minibatchSourceConfiguration;
-    minibatchSourceConfiguration[L"epochSize"] = (size_t)2000;
-    minibatchSourceConfiguration[L"deserializers"] = std::vector<CNTK::DictionaryValue>({ deserializerConfiguration });
-
-    return CreateCompositeMinibatchSource(minibatchSourceConfiguration);
-}
-
 void TrainSequenceToSequenceTranslator(const DeviceDescriptor& device, bool useSparseInputs, bool testSaveAndReLoad, bool testCheckpointing)
 {
     using namespace std::placeholders;
@@ -150,9 +118,14 @@ void TrainSequenceToSequenceTranslator(const DeviceDescriptor& device, bool useS
         errs = errsVar;
     }

-    auto minibatchSource = CreateSeq2SeqMinibatchSource(L"cmudict-0.7b.train-dev-20-21.ctf", inputVocabDim, labelVocabDim);
-    auto rawInputStreamInfo = minibatchSource->StreamInfo(L"rawInput");
-    auto rawLabelsStreamInfo = minibatchSource->StreamInfo(L"rawLabels");
+    auto featureStreamName = L"rawInput";
+    auto labelStreamName = L"rawLabels";
+    auto minibatchSource = TextFormatMinibatchSource(L"cmudict-0.7b.train-dev-20-21.ctf",
+                                                     { { featureStreamName, inputVocabDim, true, L"S0" }, { labelStreamName, labelVocabDim, true, L"S1" } },
+                                                     5000);
+
+    auto rawInputStreamInfo = minibatchSource->StreamInfo(featureStreamName);
+    auto rawLabelsStreamInfo = minibatchSource->StreamInfo(labelStreamName);

     double learningRatePerSample = 0.007;
     size_t momentumTimeConstant = 1100;
@@ -164,7 +137,7 @@ void TrainSequenceToSequenceTranslator(const DeviceDescriptor& device, bool useS
     size_t outputFrequencyInMinibatches = 1;
     size_t minibatchSize = 72;
     size_t numMinibatchesToCheckpointAfter = testCheckpointing ? 3 : SIZE_MAX;
-    size_t numMinibatchesToRestoreFromCheckpointAfter = testCheckpointing ? 6 : SIZE_MAX;
+    size_t numMinibatchesToRestoreFromCheckpointAfter = testCheckpointing ? 20 : SIZE_MAX;
     bool restorationDone = false;
     const wchar_t* modelFile = L"seq2seq.model";
     for (size_t i = 0; true; i++)
@@ -172,25 +145,7 @@ void TrainSequenceToSequenceTranslator(const DeviceDescriptor& device, bool useS
         if (!restorationDone && (i == numMinibatchesToRestoreFromCheckpointAfter))
         {
             printf("Trainer restoring from checkpoint at path %S\n", modelFile);
-            auto inputs = trainer.LossFunction()->Inputs();
-            auto findInputVariableIndex = [&inputs](const Variable& inputVar) {
-                for (size_t i = 0; i < inputs.size(); ++i)
-                {
-                    if (inputs[i] == inputVar)
-                        return i;
-                }
-
-                LogicError("Specified variable is not an input of the loss function");
-            };
-
-            size_t rawInputIndex = findInputVariableIndex(rawInput);
-            size_t rawLabelsIndex = findInputVariableIndex(rawLabels);
-
             trainer.RestoreFromCheckpoint(modelFile);
-
-            rawInput = trainer.LossFunction()->Inputs()[rawInputIndex];
-            rawLabels = trainer.LossFunction()->Inputs()[rawLabelsIndex];
-
             i = numMinibatchesToCheckpointAfter;
             restorationDone = true;
         }
@@ -213,8 +168,6 @@ void TrainSequenceToSequenceTranslator(const DeviceDescriptor& device, bool useS
 void TrainSequenceToSequenceTranslator()
 {
     // TODO: Also test with sparse input variables in the graph
-    // TODO: Also test trainer checkpointing
-
-    TrainSequenceToSequenceTranslator(DeviceDescriptor::GPUDevice(0), false, true, false);
-    TrainSequenceToSequenceTranslator(DeviceDescriptor::CPUDevice(), false, false, false);
+    TrainSequenceToSequenceTranslator(DeviceDescriptor::GPUDevice(0), false, false, true);
+    TrainSequenceToSequenceTranslator(DeviceDescriptor::CPUDevice(), false, true, false);
 }
diff --git a/Tests/UnitTests/V2LibraryTests/SequenceClassification.cpp b/Tests/UnitTests/V2LibraryTests/SequenceClassification.cpp
index ae3b68733..c8e061b24 100644
--- a/Tests/UnitTests/V2LibraryTests/SequenceClassification.cpp
+++ b/Tests/UnitTests/V2LibraryTests/SequenceClassification.cpp
@@ -53,7 +53,7 @@ void TrainLSTMSequenceClassifer(const DeviceDescriptor& device, bool testSaveAnd
         prediction = predictionVar;
     }

-    auto minibatchSource = CreateTextMinibatchSource(L"Train.ctf", inputDim, numOutputClasses, 0, true, false, L"x", L"y");
+    auto minibatchSource = TextFormatMinibatchSource(L"Train.ctf", { { L"features", inputDim, true, L"x" }, { L"labels", numOutputClasses, false, L"y" } }, 0);

     const size_t minibatchSize = 200;
     auto featureStreamInfo = minibatchSource->StreamInfo(features);
diff --git a/Tests/UnitTests/V2LibraryTests/TrainerTests.cpp b/Tests/UnitTests/V2LibraryTests/TrainerTests.cpp
index e00b502a4..ca69186a2 100644
--- a/Tests/UnitTests/V2LibraryTests/TrainerTests.cpp
+++ b/Tests/UnitTests/V2LibraryTests/TrainerTests.cpp
@@ -18,7 +18,7 @@ void TrainSimpleFeedForwardClassifer(const DeviceDescriptor& device)
     const size_t numSweepsToTrainWith = 2;
     const size_t numMinibatchesToTrain = (numSamplesPerSweep * numSweepsToTrainWith) / minibatchSize;

-    auto minibatchSource = CreateTextMinibatchSource(L"SimpleDataTrain_cntk_text.txt", (size_t)2, (size_t)2, 0);
+    auto minibatchSource = TextFormatMinibatchSource(L"SimpleDataTrain_cntk_text.txt", { { L"features", inputDim }, { L"labels", numOutputClasses } }, 0);
     auto streamInfos = minibatchSource->StreamInfos();
     auto featureStreamInfo = std::find_if(streamInfos.begin(), streamInfos.end(), [](const StreamInformation& streamInfo) { return (streamInfo.m_name == L"features"); });
     auto labelStreamInfo = std::find_if(streamInfos.begin(), streamInfos.end(), [](const StreamInformation& streamInfo) { return (streamInfo.m_name == L"labels"); });
@@ -55,7 +55,7 @@ void TrainSimpleFeedForwardClassifer(const DeviceDescriptor& device)
     }

     double learningRatePerSample = 0.02;
-    minibatchSource = CreateTextMinibatchSource(L"SimpleDataTrain_cntk_text.txt", (size_t)2, (size_t)2, SIZE_MAX);
+    minibatchSource = TextFormatMinibatchSource(L"SimpleDataTrain_cntk_text.txt", { { L"features", inputDim }, { L"labels", numOutputClasses } });
     Trainer trainer(classifierOutput, trainingLoss, prediction, { SGDLearner(classifierOutput->Parameters(), learningRatePerSample) });
     size_t outputFrequencyInMinibatches = 20;
     for (size_t i = 0; i < numMinibatchesToTrain; ++i)
@@ -101,11 +101,12 @@ void TrainMNISTClassifier(const DeviceDescriptor& device)
     const size_t numSweepsToTrainWith = 3;
     const size_t numMinibatchesToTrain = (numSamplesPerSweep * numSweepsToTrainWith) / minibatchSize;

-    auto minibatchSource = CreateTextMinibatchSource(L"Train-28x28_cntk_text.txt", (size_t)784, (size_t)10, SIZE_MAX);
+    auto featureStreamName = L"features";
+    auto labelsStreamName = L"labels";
+    auto minibatchSource = TextFormatMinibatchSource(L"Train-28x28_cntk_text.txt", { { featureStreamName, inputDim }, { labelsStreamName, numOutputClasses } });

-    auto streamInfos = minibatchSource->StreamInfos();
-    auto featureStreamInfo = std::find_if(streamInfos.begin(), streamInfos.end(), [](const StreamInformation& streamInfo) { return (streamInfo.m_name == L"features"); });
-    auto labelStreamInfo = std::find_if(streamInfos.begin(), streamInfos.end(), [](const StreamInformation& streamInfo) { return (streamInfo.m_name == L"labels"); });
+    auto featureStreamInfo = minibatchSource->StreamInfo(featureStreamName);
+    auto labelStreamInfo = minibatchSource->StreamInfo(labelsStreamName);

     double learningRatePerSample = 0.003125;
     Trainer trainer(classifierOutput, trainingLoss, prediction, { SGDLearner(classifierOutput->Parameters(), learningRatePerSample) });
@@ -114,7 +115,7 @@ void TrainMNISTClassifier(const DeviceDescriptor& device)
     for (size_t i = 0; i < numMinibatchesToTrain; ++i)
     {
         auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
-        trainer.TrainMinibatch({ { input, minibatchData[*featureStreamInfo].m_data }, { labels, minibatchData[*labelStreamInfo].m_data } }, device);
+        trainer.TrainMinibatch({ { input, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
         PrintTrainingProgress(trainer, i, outputFrequencyInMinibatches);
     }
 }
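Taken together, the Trainer changes enable a minimal checkpointing loop like the one sketched below, in the spirit of the Seq2Seq and MNIST tests above. Model construction and data binding are elided; 'trainer', 'source', 'input', 'labels', the stream infos, the device, and the size constants (including the hypothetical 'checkpointFrequency') are assumed to be set up as in those tests:

    const wchar_t* checkpointFile = L"model.checkpoint";
    for (size_t i = 0; i < numMinibatchesToTrain; ++i)
    {
        auto minibatchData = source->GetNextMinibatch(minibatchSize, device);
        trainer.TrainMinibatch({ { input, minibatchData[featureStreamInfo].m_data },
                                 { labels, minibatchData[labelStreamInfo].m_data } }, device);

        // Periodically persist the model together with the learner state.
        if ((i % checkpointFrequency) == 0)
            trainer.SaveCheckpoint(checkpointFile);
    }

    // Later (or after a crash), resume from the last saved state. Because the
    // restore path matches leaf variables by uid and copies values in place,
    // previously obtained Variable references into the trainer's model remain valid.
    trainer.RestoreFromCheckpoint(checkpointFile);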