CNTK v2 library: Trainer checkpointing and some MinibatchSource creation helper APIs

Amit Agarwal 2016-09-03 17:31:02 -07:00
Parent 79ad971fe2
Commit abf468097b
14 changed files with 315 additions and 223 deletions

View file

@ -501,10 +501,19 @@ namespace CNTK
/// Fill 'this' NDArrayView with the specified value. The underlying DataType of 'this' view should be DataType::Double.
///
CNTK_API void SetValue(double value);
///
/// Creates a new NDArrayView with newly allocated storage on the specified device and copies 'this' view's contents into the newly allocated view.
///
CNTK_API NDArrayViewPtr DeepClone(const DeviceDescriptor& device, bool readOnly = false) const;
///
/// Creates a new NDArrayView with newly allocated storage on the same device as 'this' view and copies 'this' view's contents into the newly allocated view.
///
CNTK_API NDArrayViewPtr DeepClone(bool readOnly = false) const;
inline NDArrayViewPtr DeepClone(bool readOnly = false) const
{
return DeepClone(this->Device(), readOnly);
}
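A minimal usage sketch of the two DeepClone overloads; the shape, fill value, and devices below are illustrative:

NDArrayViewPtr view = MakeSharedObject<NDArrayView>(0.5f, NDShape({ 3, 4 }), DeviceDescriptor::CPUDevice());
NDArrayViewPtr cpuCopy = view->DeepClone();                                                    // new storage on the same device
NDArrayViewPtr gpuCopy = view->DeepClone(DeviceDescriptor::GPUDevice(0), /*readOnly =*/ true); // new storage on GPU 0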
///
/// Creates a new NDArrayView which is an alias of 'this' view; i.e. a new view of the same shape as 'this' over the same underlying data.
@ -854,6 +863,33 @@ namespace CNTK
Placeholder
};
inline const wchar_t* VariableKindName(VariableKind variableKind)
{
switch (variableKind)
{
case VariableKind::Input:
return L"Input";
case VariableKind::Output:
return L"Output";
case VariableKind::Parameter:
return L"Parameter";
case VariableKind::Constant:
return L"Constant";
case VariableKind::Placeholder:
return L"Placeholder";
default:
LogicError("Unknown VariableKind");
}
}
namespace Internal
{
inline std::wstring GenerateUid(VariableKind varKind)
{
return std::wstring(VariableKindName(varKind)) + std::to_wstring(Internal::NewUniqueId());
}
}
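A small illustration of the resulting UID scheme (the concrete values are illustrative; NewUniqueId is a process-wide counter, so uniqueness holds only within a single process):

std::wstring paramUid = Internal::GenerateUid(VariableKind::Parameter); // e.g. L"Parameter0"
std::wstring inputUid = Internal::GenerateUid(VariableKind::Input);     // e.g. L"Input1"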
// Forward declarations
inline Variable PlaceholderVariable(const NDShape& shape, const std::vector<Axis>& dynamicAxes = Axis::DefaultInputVariableDynamicAxes);
inline Variable InputVariable(const NDShape& shape, bool isSparse, CNTK::DataType dataType, bool needsGradient, const std::wstring& name = L"", const std::vector<Axis>& dynamicAxes = Axis::DefaultInputVariableDynamicAxes);
@ -874,18 +910,18 @@ namespace CNTK
template <typename T>
friend struct std::hash;
template <typename ElementType>
friend Variable GetVariable(const Microsoft::MSR::CNTK::ComputationNodeBasePtr& node,
std::unordered_map<Microsoft::MSR::CNTK::ComputationNodeBasePtr, Variable>& nodeToVariableMap,
std::unordered_map<Variable, Variable>& placeholderReplacements,
std::unordered_set<FunctionPtr>& allPrimitiveFunctions);
private:
friend inline Variable PlaceholderVariable(const NDShape& shape, const std::vector<Axis>& dynamicAxes /*= Axis::DefaultInputVariableDynamicAxes*/);
friend inline Variable InputVariable(const NDShape& shape, bool isSparse, CNTK::DataType dataType, bool needsGradient, const std::wstring& name /*= L""*/, const std::vector<Axis>& dynamicAxes /*= Axis::DefaultInputVariableDynamicAxes*/);
friend inline Variable OutputVariable(const NDShape& shape, CNTK::DataType dataType, Function* ownerFunction, const std::vector<Axis>& dynamicAxes, const std::wstring& name /*= L""*/);
public:
///
/// Create an 'Input' Variable
///
Variable(const NDShape& shape, bool isSparse, CNTK::DataType dataType, bool needsGradient, const std::wstring& name, const std::vector<Axis>& dynamicAxes = Axis::DefaultInputVariableDynamicAxes)
: Variable(shape, VariableKind::Input, dataType, nullptr, nullptr, needsGradient, dynamicAxes, isSparse, name)
{}
public:
///
/// Create an 'Output' variable aliasing the output of the specified Function
@ -954,6 +990,11 @@ namespace CNTK
///
const std::wstring& Name() const { return m_dataFields->m_name; }
///
/// Returns the internally generated unique name of the variable
///
const std::wstring& Uid() const { return m_dataFields->m_uid; }
///
/// Returns the Function object which 'this' variable is an output of.
/// Returns null when called for a Variable that is not of 'Output' VariableKind.
@ -971,8 +1012,8 @@ namespace CNTK
bool NeedsGradient() const { return m_dataFields->m_needsGradient; }
protected:
Variable(const NDShape& shape, VariableKind varType, CNTK::DataType dataType, const NDArrayViewPtr& value, bool needsGradient, const std::vector<Axis>& dynamicAxes, const std::wstring& name)
: Variable(shape, varType, dataType, nullptr, value, needsGradient, dynamicAxes, /*isSparse =*/ false, name)
Variable(const NDShape& shape, VariableKind varType, CNTK::DataType dataType, const NDArrayViewPtr& value, bool needsGradient, const std::vector<Axis>& dynamicAxes, const std::wstring& name, const std::wstring& uid)
: Variable(shape, varType, dataType, nullptr, value, needsGradient, dynamicAxes, /*isSparse =*/ false, name, uid)
{}
NDArrayViewPtr Value() const
@ -982,8 +1023,13 @@ namespace CNTK
}
private:
Variable(const NDShape& shape, VariableKind varType, CNTK::DataType dataType, Function* ownerFunction, const NDArrayViewPtr& value, bool needsGradient, const std::vector<Axis>& dynamicAxes, bool isSparse, const std::wstring& name)
: m_dataFields(MakeSharedObject<VariableFields>(shape, varType, dataType, ownerFunction, value, needsGradient, dynamicAxes, isSparse, name))
Variable(const NDShape& shape, bool isSparse, CNTK::DataType dataType, bool needsGradient, const std::wstring& name, const std::vector<Axis>& dynamicAxes, const std::wstring& uid)
: Variable(shape, VariableKind::Input, dataType, nullptr, nullptr, needsGradient, dynamicAxes, isSparse, name, uid)
{}
Variable(const NDShape& shape, VariableKind varType, CNTK::DataType dataType, Function* ownerFunction, const NDArrayViewPtr& value, bool needsGradient, const std::vector<Axis>& dynamicAxes, bool isSparse, const std::wstring& name, const std::wstring& uid)
: m_dataFields(MakeSharedObject<VariableFields>(shape, varType, dataType, ownerFunction, value, needsGradient, dynamicAxes, isSparse, name, uid))
{}
private:
@ -1001,9 +1047,10 @@ namespace CNTK
std::wstring m_name;
std::vector<Axis> m_dynamicAxes;
bool m_isSparse;
std::wstring m_uid;
VariableFields(const NDShape& shape, VariableKind varType, CNTK::DataType type, Function* ownerFunction, const NDArrayViewPtr& value, bool needsGradient, const std::vector<Axis>& dynamicAxes, bool isSparse, const std::wstring& name)
: m_shape(shape), m_varKind(varType), m_dataType(type), m_ownerFunction(ownerFunction), m_value(value), m_needsGradient(needsGradient), m_dynamicAxes(dynamicAxes), m_isSparse(isSparse), m_name(name)
VariableFields(const NDShape& shape, VariableKind varType, CNTK::DataType type, Function* ownerFunction, const NDArrayViewPtr& value, bool needsGradient, const std::vector<Axis>& dynamicAxes, bool isSparse, const std::wstring& name, const std::wstring& uid)
: m_shape(shape), m_varKind(varType), m_dataType(type), m_ownerFunction(ownerFunction), m_value(value), m_needsGradient(needsGradient), m_dynamicAxes(dynamicAxes), m_isSparse(isSparse), m_name(name), m_uid(uid)
{
if (value && (type != value->GetDataType()))
InvalidArgument("The DataType of the Parameter/Constant Variable does not match the DataType of the associated Value");
@ -1043,7 +1090,16 @@ namespace CNTK
///
inline Variable PlaceholderVariable(const NDShape& shape, const std::vector<Axis>& dynamicAxes /*= Axis::DefaultInputVariableDynamicAxes*/)
{
return Variable(shape, VariableKind::Placeholder, DataType::Unknown, nullptr, false, dynamicAxes, L"");
auto varKind = VariableKind::Placeholder;
return Variable(shape, varKind, DataType::Unknown, nullptr, false, dynamicAxes, L"", Internal::GenerateUid(varKind));
}
///
/// Create an 'Input' Variable denoting sparse data and specify if gradients are to be computed for this input
///
inline Variable InputVariable(const NDShape& shape, bool isSparse, CNTK::DataType dataType, bool needsGradient, const std::wstring& name /*= L""*/, const std::vector<Axis>& dynamicAxes /*= Axis::DefaultInputVariableDynamicAxes*/)
{
return Variable(shape, isSparse, dataType, needsGradient, name, dynamicAxes, Internal::GenerateUid(VariableKind::Input));
}
///
@ -1051,7 +1107,7 @@ namespace CNTK
///
inline Variable InputVariable(const NDShape& shape, CNTK::DataType dataType, bool needsGradient, const std::wstring& name = L"", const std::vector<Axis>& dynamicAxes = Axis::DefaultInputVariableDynamicAxes)
{
return Variable(shape, /*isSparse =*/ false, dataType, needsGradient, name, dynamicAxes);
return InputVariable(shape, /*isSparse =*/ false, dataType, needsGradient, name, dynamicAxes);
}
///
@ -1078,14 +1134,6 @@ namespace CNTK
return InputVariable(shape, dataType, L"", dynamicAxes);
}
///
/// Create an 'Input' Variable denoting sparse data and specify if gradients are to be computed for this input
///
inline Variable InputVariable(const NDShape& shape, bool isSparse, CNTK::DataType dataType, bool needsGradient, const std::wstring& name /*= L""*/, const std::vector<Axis>& dynamicAxes /*= Axis::DefaultInputVariableDynamicAxes*/)
{
return Variable(shape, isSparse, dataType, needsGradient, name, dynamicAxes);
}
///
/// Create an 'Input' Variable denoting sparse data.
///
@ -1115,7 +1163,7 @@ namespace CNTK
///
inline Variable OutputVariable(const NDShape& shape, CNTK::DataType dataType, Function* ownerFunction, const std::vector<Axis>& dynamicAxes, const std::wstring& name /*= L""*/)
{
return Variable(shape, VariableKind::Output, dataType, ownerFunction, nullptr, /*needsGradient =*/ false, dynamicAxes, /*isSparse =*/ false, name);
return Variable(shape, VariableKind::Output, dataType, ownerFunction, nullptr, /*needsGradient =*/ false, dynamicAxes, /*isSparse =*/ false, name, Internal::GenerateUid(VariableKind::Output));
}
///
@ -1126,12 +1174,18 @@ namespace CNTK
template <typename T>
friend struct std::hash;
template <typename ElementType>
friend Variable GetVariable(const Microsoft::MSR::CNTK::ComputationNodeBasePtr& node,
std::unordered_map<Microsoft::MSR::CNTK::ComputationNodeBasePtr, Variable>& nodeToVariableMap,
std::unordered_map<Variable, Variable>& placeholderReplacements,
std::unordered_set<FunctionPtr>& allPrimitiveFunctions);
public:
///
/// Construct a parameter whose initial contents are a copy of the specified 'value'
///
explicit Parameter(const NDArrayViewPtr& value, const std::wstring& name = L"")
: Variable(value->Shape(), VariableKind::Parameter, value->GetDataType(), value->DeepClone(), true, {}, name)
: Parameter(value, name, Internal::GenerateUid(VariableKind::Parameter))
{}
// TODO: Constructor to move a specified NDArrayView value
@ -1141,14 +1195,14 @@ namespace CNTK
///
template<typename ElemType>
Parameter(const NDShape& shape, ElemType initValue, const DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice(), const std::wstring& name = L"")
: Variable(shape, VariableKind::Parameter, AsDataType<ElemType>(), MakeSharedObject<NDArrayView>(initValue, shape, device), true, {}, name)
: Variable(shape, VariableKind::Parameter, AsDataType<ElemType>(), MakeSharedObject<NDArrayView>(initValue, shape, device), true, {}, name, Internal::GenerateUid(VariableKind::Parameter))
{}
///
/// Construct a parameter of specified shape whose contents are initialized with the specified 'initValue'
///
Parameter(const NDShape& shape, DataType dataType, double initValue, const DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice(), const std::wstring& name = L"")
: Variable(shape, VariableKind::Parameter, dataType, MakeSharedObject<NDArrayView>(initValue, dataType, shape, device), true, {}, name)
: Variable(shape, VariableKind::Parameter, dataType, MakeSharedObject<NDArrayView>(initValue, dataType, shape, device), true, {}, name, Internal::GenerateUid(VariableKind::Parameter))
{}
///
@ -1233,6 +1287,11 @@ namespace CNTK
return Variable::Value();
}
private:
explicit Parameter(const NDArrayViewPtr& value, const std::wstring& name, const std::wstring& uid)
: Variable(value->Shape(), VariableKind::Parameter, value->GetDataType(), value->DeepClone(), true, {}, name, uid)
{}
private:
// Helper methods for Parameter construction
@ -1254,12 +1313,18 @@ namespace CNTK
template <typename T>
friend struct std::hash;
template <typename ElementType>
friend Variable GetVariable(const Microsoft::MSR::CNTK::ComputationNodeBasePtr& node,
std::unordered_map<Microsoft::MSR::CNTK::ComputationNodeBasePtr, Variable>& nodeToVariableMap,
std::unordered_map<Variable, Variable>& placeholderReplacements,
std::unordered_set<FunctionPtr>& allPrimitiveFunctions);
public:
///
/// Construct a Constant whose initial contents are a copy of the specified value
///
Constant(const NDArrayViewPtr& value, const std::wstring& name = L"")
: Variable(value->Shape(), VariableKind::Constant, value->GetDataType(), value->DeepClone(true), false, {}, name)
: Constant(value, name, Internal::GenerateUid(VariableKind::Constant))
{}
// TODO: Constructor to move a specified NDArrayView value
@ -1269,14 +1334,14 @@ namespace CNTK
///
template<typename ElemType>
Constant(const NDShape& shape, ElemType initValue, const DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice(), const std::wstring& name = L"")
: Variable(shape, VariableKind::Constant, AsDataType<ElemType>(), MakeSharedObject<NDArrayView>(initValue, shape, device), false, {}, name)
: Variable(shape, VariableKind::Constant, AsDataType<ElemType>(), MakeSharedObject<NDArrayView>(initValue, shape, device), false, {}, name, Internal::GenerateUid(VariableKind::Constant))
{}
///
/// Construct a constant of specified shape whose contents are initialized with the specified 'initValue'
///
Constant(const NDShape& shape, DataType dataType, double initValue, const DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice(), const std::wstring& name = L"")
: Variable(shape, VariableKind::Constant, dataType, MakeSharedObject<NDArrayView>(initValue, dataType, shape, device), false, {}, name)
: Variable(shape, VariableKind::Constant, dataType, MakeSharedObject<NDArrayView>(initValue, dataType, shape, device), false, {}, name, Internal::GenerateUid(VariableKind::Constant))
{}
///
@ -1313,6 +1378,11 @@ namespace CNTK
{
return Variable::Value();
}
private:
Constant(const NDArrayViewPtr& value, const std::wstring& name, const std::wstring& uid)
: Variable(value->Shape(), VariableKind::Constant, value->GetDataType(), value->DeepClone(true), false, {}, name, uid)
{}
};
// Implementation note: The Variable type is a value type and not polymorphic in nature.
@ -1799,8 +1869,8 @@ namespace CNTK
///
inline FunctionPtr PastValue(const Variable& operand, size_t offset = 1, const std::wstring& name = L"")
{
const Variable& initialState = Constant::Scalar(0.0f);
return PastValue(operand, initialState, offset, name);
static const auto defaultInitialState = Constant::Scalar(0.0f);
return PastValue(operand, defaultInitialState, offset, name);
}
///
@ -1816,8 +1886,8 @@ namespace CNTK
///
inline FunctionPtr FutureValue(const Variable& operand, size_t offset = 1, const std::wstring& name = L"")
{
const Variable& initialState = Constant::Scalar(0.0f);
return FutureValue(operand, initialState, offset, name);
static const auto defaultInitialState = Constant::Scalar(0.0f);
return FutureValue(operand, defaultInitialState, offset, name);
}
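A usage sketch for these convenience overloads (the variable name and shape are illustrative); both forward to the explicit-initial-state overloads with a zero scalar Constant that is created once and shared across calls:

Variable x = InputVariable({ 10 }, DataType::Float, /*needsGradient =*/ false);
FunctionPtr prev = PastValue(x);                   // operand at t-1; zero initial state at t = 0
FunctionPtr next = FutureValue(x, /*offset =*/ 2); // operand at t+2; zero initial state at the sequence end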
///
@ -2722,6 +2792,53 @@ namespace CNTK
///
CNTK_API MinibatchSourcePtr CreateCompositeMinibatchSource(const Dictionary& configuration);
struct StreamConfiguration
{
StreamConfiguration(const std::wstring& streamName, size_t dim, bool isSparse = false, const std::wstring& streamAlias = L"")
: m_streamName(streamName), m_dim(dim), m_isSparse(isSparse), m_streamAlias(streamAlias)
{}
std::wstring m_streamName;
size_t m_dim;
bool m_isSparse;
std::wstring m_streamAlias;
};
///
/// Instantiate the CNTK built-in text format minibatch source
///
inline MinibatchSourcePtr TextFormatMinibatchSource(const std::wstring& dataFilePath, const std::vector<StreamConfiguration>& streamConfigs, size_t epochSize = SIZE_MAX)
{
CNTK::Dictionary minibatchSourceConfiguration;
minibatchSourceConfiguration[L"epochSize"] = epochSize;
CNTK::Dictionary deserializerConfiguration;
deserializerConfiguration[L"type"] = L"CNTKTextFormatDeserializer";
deserializerConfiguration[L"file"] = dataFilePath;
CNTK::Dictionary inputStreamsConfig;
for (auto streamConfig : streamConfigs)
{
std::wstring streamName = streamConfig.m_streamName;
size_t streamDim = streamConfig.m_dim;
bool isSparse = streamConfig.m_isSparse;
std::wstring streamAlias = streamConfig.m_streamAlias;
CNTK::Dictionary inputStreamConfig;
inputStreamConfig[L"dim"] = streamDim;
inputStreamConfig[L"format"] = isSparse ? L"sparse" : L"dense";
if (!streamAlias.empty())
inputStreamConfig[L"alias"] = streamAlias;
inputStreamsConfig[streamName] = inputStreamConfig;
}
deserializerConfiguration[L"input"] = inputStreamsConfig;
minibatchSourceConfiguration[L"deserializers"] = std::vector<CNTK::DictionaryValue>({ deserializerConfiguration });
return CreateCompositeMinibatchSource(minibatchSourceConfiguration);
}
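For example (the file path, stream names, and aliases are illustrative), a source over a CTF file with a dense feature stream and a sparse label stream:

auto source = TextFormatMinibatchSource(L"Train.ctf",
                                        { { L"features", 784, /*isSparse =*/ false, L"x" },
                                          { L"labels", 10, /*isSparse =*/ true, L"y" } });
auto featureStreamInfo = source->StreamInfo(L"features");
auto labelStreamInfo = source->StreamInfo(L"labels");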
///
/// Compute the per-dimension means and variances for each of the specified streams using data from the specified minibatchSource.
///

View file

@ -53,6 +53,9 @@ namespace Microsoft { namespace MSR { namespace CNTK {
template <typename ElementType>
class ComputationNode;
class ComputationNodeBase;
typedef std::shared_ptr<ComputationNodeBase> ComputationNodeBasePtr;
}}}
// TODO: The following should be reconciled with the equivalent code in the CNTK implementation
@ -195,5 +198,7 @@ namespace CNTK
CNTK_API FunctionPtr Scatter(const Variable& operand, const Variable& condition, const std::vector<Axis>& newDynamicAxes, const std::wstring& name = L"");
CNTK_API FunctionPtr Slice(const Variable& operand, const Axis& axis, int beginIndex, int endIndex, const std::wstring& name = L"");
CNTK_API FunctionPtr ReduceElements(const Variable& operand, const std::wstring& reductionOpName, const Axis& axis, const std::wstring& name = L"");
CNTK_API size_t NewUniqueId();
}
}

View file

@ -45,7 +45,7 @@ namespace CNTK
auto inputNodeInternalDynamicAxisName = node->GetMBLayout()->GetAxisName();
std::vector<Axis> inputVarDynamicAxes = DynamicAxesFromInternalDynamicAxisName(inputNodeInternalDynamicAxisName);
var = Variable(varShape, isSparse, AsDataType<ElementType>(), node->GetLearningRateMultiplier() != 0, node->GetName(), inputVarDynamicAxes);
var = Variable(varShape, isSparse, AsDataType<ElementType>(), node->GetLearningRateMultiplier() != 0, node->NodeName(), inputVarDynamicAxes, node->NodeName());
}
else
{
@ -58,11 +58,11 @@ namespace CNTK
bool isConstant = (node->GetLearningRateMultiplier() == 0);
auto& matrix = node->As<ComputationNode<ElementType>>()->Value();
auto tensorView = new TensorView<ElementType>(std::make_shared<Matrix<ElementType>>(matrix.AsReference()), AsTensorViewShape(node->GetSampleLayout()));
NDArrayViewPtr parameterValue = MakeSharedObject<NDArrayView>(AsDataType<ElementType>(), AsDeviceDescriptor(matrix.GetDeviceId()), AsStorageFormat(matrix.GetFormat()), varShape, false, tensorView);
NDArrayViewPtr value = MakeSharedObject<NDArrayView>(AsDataType<ElementType>(), AsDeviceDescriptor(matrix.GetDeviceId()), AsStorageFormat(matrix.GetFormat()), varShape, false, tensorView);
if (isConstant)
var = Constant(parameterValue, node->GetName());
var = Constant(value, node->NodeName(), node->NodeName());
else
var = Parameter(parameterValue, node->GetName());
var = Parameter(value, node->NodeName(), node->NodeName());
}
else
LogicError("CNTK::LoadLegacyModel: Unsupported legacy CNTK node named '%S'", node->NodeName().c_str());
@ -276,7 +276,7 @@ namespace CNTK
// Reorder inputVars, since the input ordering of the internal CNTK ComputationNode may differ from the PrimitiveFunction's input ordering
ReorderAsPrimitiveFunctionInputs(opType, inputVars);
FunctionPtr primitiveFunction = MakeSharedObject<PrimitiveFunction>(opType, inputVars, std::move(primitiveFunctionConfigParameters), node->GetName());
FunctionPtr primitiveFunction = MakeSharedObject<PrimitiveFunction>(opType, inputVars, std::move(primitiveFunctionConfigParameters), node->NodeName());
allPrimitiveFunctions.insert(primitiveFunction);
var = primitiveFunction->Output();
if (placeholderReplacements.find(placeholderVar) != placeholderReplacements.end())

View file

@ -8,6 +8,15 @@
namespace CNTK
{
namespace Internal
{
size_t NewUniqueId()
{
static std::atomic<unsigned long long> s_nextUniqueId(0);
return s_nextUniqueId++;
}
}
/*static*/ std::atomic<bool> DeviceDescriptor::s_defaultDeviceFrozen(false);
/*static*/ std::shared_ptr<DeviceDescriptor> DeviceDescriptor::s_defaultDevice(new DeviceDescriptor(DeviceDescriptor::GPUDevice(0)));

View file

@ -451,15 +451,11 @@ namespace CNTK
std::shared_ptr<ComputationNode<ElementType>> computationNodePtr;
if (variable.IsParameter() || variable.IsConstant())
{
computationNodePtr = builder.CreateLearnableParameter(variable.Name(), AsTensorShape(variable.Shape()));
computationNodePtr = builder.CreateLearnableParameter(variable.Uid(), AsTensorShape(variable.Shape()));
network->InitLearnableParameters(computationNodePtr, L"fixedValue", 0); // must call this to follow protocol; can overwrite later
if (!variable.NeedsGradient())
computationNodePtr->SetLearningRateMultiplier(0.0);
// If the parameter variable does not have a name assign it the internal computation node name
if (variable.Name().empty())
variable.m_dataFields->m_name = computationNodePtr->NodeName();
NDArrayViewPtr value = variable.IsConstant() ? Constant(variable).Value() : Parameter(variable).Value();
std::shared_ptr<const Matrix<ElementType>> valueMatrix = variable.IsConstant() ? value->GetMatrix<ElementType>() : value->GetWritableMatrix<ElementType>();
if (variable.IsParameter() || (valueMatrix->GetDeviceId() == network->GetDeviceId()))
@ -493,9 +489,9 @@ namespace CNTK
network->AddNodeToNetAndAttachInputs(New<DynamicAxisNode<ElementType>>(network->GetDeviceId(), internalDynamicAxisName), {});
if (IsSparseInput(variable))
computationNodePtr = builder.CreateSparseInputNode(variable.Name(), AsTensorShape(variable.Shape()), internalDynamicAxisName);
computationNodePtr = builder.CreateSparseInputNode(variable.Uid(), AsTensorShape(variable.Shape()), internalDynamicAxisName);
else
computationNodePtr = builder.CreateInputNode(variable.Name(), AsTensorShape(variable.Shape()), internalDynamicAxisName);
computationNodePtr = builder.CreateInputNode(variable.Uid(), AsTensorShape(variable.Shape()), internalDynamicAxisName);
if (variable.NeedsGradient())
{
@ -796,36 +792,11 @@ namespace CNTK
// If the inputVar is a constant and not the right DataType, let's cast it to the right type
if (inputVar.IsConstant() && (nonConstInputDataType != DataType::Unknown) && (inputVar.GetDataType() != nonConstInputDataType))
{
auto constantValue = Constant(inputVar).Value();
NDArrayView constantValueCPU(constantValue->GetDataType(), constantValue->Shape(), DeviceDescriptor::CPUDevice());
constantValueCPU.CopyFrom(*constantValue);
NDArrayViewPtr newConstantValue;
if (inputVar.GetDataType() == DataType::Float)
{
// Cast to double
const float* buffer = constantValueCPU.DataBuffer<float>();
double* castValue = new double[constantValueCPU.Shape().TotalSize()];
for (size_t i = 0; i < constantValueCPU.Shape().TotalSize(); ++i)
castValue[i] = buffer[i];
newConstantValue = MakeSharedObject<NDArrayView>(constantValue->Shape(), castValue, constantValueCPU.Shape().TotalSize(), DeviceDescriptor::CPUDevice());
}
else
{
// Cast to float
const double* buffer = constantValueCPU.DataBuffer<double>();
float* castValue = new float[constantValueCPU.Shape().TotalSize()];
for (size_t i = 0; i < constantValueCPU.Shape().TotalSize(); ++i)
castValue[i] = (float)(buffer[i]);
newConstantValue = MakeSharedObject<NDArrayView>(constantValue->Shape(), castValue, constantValueCPU.Shape().TotalSize(), DeviceDescriptor::CPUDevice());
}
auto constantValueCPU = Constant(inputVar).Value()->DeepClone(DeviceDescriptor::CPUDevice(), true);
NDArrayViewPtr newConstantValue = CloneAsDataType(constantValueCPU, nonConstInputDataType, true);
inputVar = Constant(newConstantValue);
}
auto baseNodePtr = GetNode(inputVar, network, builder, variableToNodeMap, isVariableRootMap);
inputNodes.push_back((baseNodePtr != nullptr) ? baseNodePtr->template As<ComputationNode<ElementType>>()->shared_from_this() : nullptr);
}

View file

@ -462,6 +462,7 @@ namespace CNTK
class CompositeFunction final : public Function
{
friend class Function;
friend class Trainer;
friend class CompositeMinibatchSource;
template <typename T, typename ...CtorArgTypes>

View file

@ -216,12 +216,12 @@ namespace CNTK
const auto& gradientValue = gradientValues.at(parameter);
// TODO: make this a runtime parameter.
#if DUMPOUTPUT
LOGPRINTF(stderr, "Update_%ls\n", parameter.Name().c_str());
LOGPRINTF(stderr, "Update_%ls\n", parameter.Uid().c_str());
#endif
#ifdef _DEBUG
if (HasNan(smoothedGradientValue, "TrainOneEpoch/UpdateWeights/Learner::Update(): "))
LogicError("%ls has NaNs in smoothedGradient.", parameter.Name().c_str());
LogicError("%ls has NaNs in smoothedGradient.", parameter.Uid().c_str());
#endif
#if DUMPOUTPUT
@ -243,7 +243,7 @@ namespace CNTK
#ifdef _DEBUG
const auto& parameterValue = parameter.Value();
if (HasNan(parameterValue, "TrainOneEpoch/UpdateWeights/Learner::Update(): "))
LogicError("%ls has NaNs in parameter values after parameter update.", parameter.Name().c_str());
LogicError("%ls has NaNs in parameter values after parameter update.", parameter.Uid().c_str());
#endif
}
m_sampleCount += trainingSampleCount;
@ -286,16 +286,13 @@ namespace CNTK
for (const auto& parameter : Parameters())
{
// TODO: parameter name is not guaranteed to be unique. Instead, all serializable objects
// need to expose "UId" property -- a persistent unique internal name.
// Switch to UId as soon as it's available.
if (checkpoint.Contains(parameter.Name()))
if (checkpoint.Contains(parameter.Uid()))
{
LogicError("Parameter names must be unique");
}
const auto& smoothedGradientValue = m_smoothedGradientValues.at(parameter);
checkpoint[parameter.Name()] = *smoothedGradientValue;
checkpoint[parameter.Uid()] = *smoothedGradientValue;
}
return checkpoint;
}
@ -314,24 +311,24 @@ namespace CNTK
for (const auto& parameter : Parameters())
{
if (!checkpoint.Contains(parameter.Name()))
if (!checkpoint.Contains(parameter.Uid()))
{
LogicError("Checkpoint does not contain state for parameter %ls", parameter.Name().c_str());
LogicError("Checkpoint does not contain state for parameter %ls", parameter.Uid().c_str());
}
const auto& smoothedGradientValue = m_smoothedGradientValues.at(parameter);
const NDArrayView& checkpointedValue = checkpoint[parameter.Name()].Value<NDArrayView>();
const NDArrayView& checkpointedValue = checkpoint[parameter.Uid()].Value<NDArrayView>();
if (smoothedGradientValue->GetDataType() != checkpointedValue.GetDataType())
{
LogicError("A value restored from a checkpoint for the smoothed gradient data type for parameter %ls does not match the expected value",
parameter.Name().c_str());
parameter.Uid().c_str());
}
if (smoothedGradientValue->Shape() != checkpointedValue.Shape())
{
LogicError("A value restored from a checkpoint for the smoothed gradient shape for parameter %ls does not match the expected value",
parameter.Name().c_str());
parameter.Uid().c_str());
}
smoothedGradientValue->CopyFrom(checkpointedValue);

View file

@ -212,9 +212,9 @@ namespace CNTK
return const_cast<TensorView<ElementType>*>(GetTensorView<ElementType>());
}
NDArrayViewPtr NDArrayView::DeepClone(bool readOnly/* = false*/) const
NDArrayViewPtr NDArrayView::DeepClone(const DeviceDescriptor& device, bool readOnly/* = false*/) const
{
NDArrayViewPtr newView = MakeSharedObject<NDArrayView>(this->GetDataType(), this->GetStorageFormat(), this->Shape(), this->Device());
NDArrayViewPtr newView = MakeSharedObject<NDArrayView>(this->GetDataType(), this->GetStorageFormat(), this->Shape(), device);
switch (m_dataType)
{
case DataType::Float:
@ -242,7 +242,7 @@ namespace CNTK
void NDArrayView::CopyFrom(const NDArrayView& source)
{
if (source.Shape() != Shape())
if ((source.Shape() != Shape()) && (AsTensorShape(source.Shape()) != AsTensorShape(Shape())))
InvalidArgument("NDArrayView::CopyFrom: The 'source' view's shape must be same as the shape of this NDArrayView");
if (IsReadOnly())

View file

@ -6,6 +6,7 @@
#include "stdafx.h"
#include "CNTKLibrary.h"
#include "Utils.h"
#include "Function.h"
namespace CNTK
{
@ -160,8 +161,6 @@ namespace CNTK
void Trainer::SaveCheckpoint(const std::wstring& modelFilePath)
{
LogicError("Trainer checkpointing is currently not supported");
SaveAsLegacyModel(m_combinedTrainingFunction, modelFilePath);
if (m_parameterLearners.size() > 1)
@ -176,38 +175,73 @@ namespace CNTK
void Trainer::RestoreFromCheckpoint(const std::wstring& modelFilePath)
{
LogicError("Trainer checkpointing is currently not supported");
auto firstLearner = *(m_parameterLearners.begin());
auto device = firstLearner->Parameters().begin()->Value()->Device();
// Determine the indices of the model, loss and evaluation functions in the combined function's outputs to properly restore them after loading the model
auto findFunctionIdx = [](const FunctionPtr& combinedFunction, const FunctionPtr& functionToFind) {
if (functionToFind->Outputs().size() != 1)
LogicError("The trainer's model, loss or evaluation functions should have onlye 1 output");
auto loadedModelFunction = LoadLegacyModel(m_combinedTrainingFunction->Outputs()[0].GetDataType(), modelFilePath, DeviceDescriptor::CPUDevice());
auto combinedOutputs = combinedFunction->Outputs();
auto functionToFindOutput = functionToFind->Output();
for (size_t i = 0; i < combinedOutputs.size(); ++i)
// TODO: Make sure that the loaded model is the same as the trainer's model through UID matching in the V2 format
// TODO: For V1 format models make sure that the loaded model is isomorphic to the trainer's model
auto loadedModelLeafVariables = loadedModelFunction->Inputs();
auto trainerModelLeafVariables = m_combinedTrainingFunction->Inputs();
if (trainerModelLeafVariables.size() != loadedModelLeafVariables.size())
InvalidArgument("The loaded model's leaf variables do not match the trainer model's leaf variables");
std::map<std::wstring, Variable> loadedModelLeafVariablesMap;
for (auto leafVar : loadedModelLeafVariables)
loadedModelLeafVariablesMap[leafVar.Uid()] = leafVar;
std::map<std::wstring, Variable> trainerModelLeafVariablesMap;
for (auto leafVar : trainerModelLeafVariables)
trainerModelLeafVariablesMap[leafVar.Uid()] = leafVar;
// Remove the initial state inputs of PastValue and FutureValue functions from the maps if they are a scalar constant
// since these are not part of the internal CNTK serialized computation graph
auto removePastAndFutureValueInitialStateScalarConstants = [](const std::unordered_set<FunctionPtr>& allPrimitiveFunctions, std::map<std::wstring, Variable>& modelLeafVariableMap) {
for (auto funcPtr : allPrimitiveFunctions)
{
if (combinedOutputs[i] == functionToFindOutput)
return i;
auto primitiveFunction = dynamic_cast<const PrimitiveFunction*>(funcPtr.get());
if ((primitiveFunction->OpType() == PrimitiveOpType::PastValue) || (primitiveFunction->OpType() == PrimitiveOpType::FutureValue))
{
auto initialStateInput = primitiveFunction->Inputs()[1];
if (initialStateInput.IsConstant() && (initialStateInput.Shape().TotalSize() == 1))
modelLeafVariableMap.erase(initialStateInput.Uid());
}
}
LogicError("Specified model/loss/evaluation function not found within the trainer's combined root function");
};
size_t modelFunctionIdx = findFunctionIdx(m_combinedTrainingFunction, m_model);
size_t lossFunctionIndex = findFunctionIdx(m_combinedTrainingFunction, m_lossFunction);
size_t evaluationFunctionIdx = SIZE_MAX;
if (m_evaluationFunction)
evaluationFunctionIdx = findFunctionIdx(m_combinedTrainingFunction, m_evaluationFunction);
auto loadedModelCompositeFunction = dynamic_cast<const CompositeFunction*>(loadedModelFunction.get());
removePastAndFutureValueInitialStateScalarConstants(loadedModelCompositeFunction->m_allPrimitiveFunctions, loadedModelLeafVariablesMap);
m_combinedTrainingFunction = LoadLegacyModel(m_combinedTrainingFunction->Outputs()[0].GetDataType(), modelFilePath, device);
m_model = Combine({ m_combinedTrainingFunction->Outputs()[modelFunctionIdx].Owner() });
m_lossFunction = Combine({ m_combinedTrainingFunction->Outputs()[lossFunctionIndex].Owner() });
if (m_evaluationFunction)
m_evaluationFunction = Combine({ m_combinedTrainingFunction->Outputs()[evaluationFunctionIdx].Owner() });
auto trainerModelCompositeFunction = dynamic_cast<const CompositeFunction*>(m_combinedTrainingFunction.get());
removePastAndFutureValueInitialStateScalarConstants(trainerModelCompositeFunction->m_allPrimitiveFunctions, trainerModelLeafVariablesMap);
// Now update the trainer's model parameters and constants with those from the loaded model
for (auto nameVarPair : trainerModelLeafVariablesMap)
{
auto trainerModelLeafVar = nameVarPair.second;
auto areVariablesEquivalent = [](const Variable& left, const Variable& right) {
return ((left.Kind() == right.Kind()) &&
((left.Shape() == right.Shape()) || (AsTensorShape(left.Shape()) == AsTensorShape(right.Shape()))) &&
(left.GetDataType() == right.GetDataType()) &&
(left.DynamicAxes().size() == right.DynamicAxes().size()) &&
(left.NeedsGradient() == right.NeedsGradient()) &&
(left.Uid() == right.Uid()) &&
(left.IsSparse() == right.IsSparse()));
};
auto correspondingLoadedModelVar = loadedModelLeafVariablesMap.at(trainerModelLeafVar.Uid());
if (!areVariablesEquivalent(correspondingLoadedModelVar, trainerModelLeafVar))
InvalidArgument("The loaded model's leaf variables do not match the trainer model's leaf variables");
if (trainerModelLeafVar.IsConstant() || trainerModelLeafVar.IsParameter())
{
auto trainerModelVarValue = trainerModelLeafVar.IsConstant() ? Constant(trainerModelLeafVar).Value() : Parameter(trainerModelLeafVar).Value();
auto loadedModelVarValue = correspondingLoadedModelVar.IsConstant() ? Constant(correspondingLoadedModelVar).Value() : Parameter(correspondingLoadedModelVar).Value();
trainerModelVarValue->CopyFrom(*loadedModelVarValue);
}
}
if (m_parameterLearners.size() > 1)
LogicError("Trainer::RestoreFromCheckpoint: Checkpointing is currently unsupported for multiple learners");

View file

@ -338,4 +338,45 @@ namespace CNTK
{
return std::pow(momentumPerSample, minibatchSize);
}
template <typename SourceElementType, typename TargetElementType>
inline TargetElementType* Copy(const SourceElementType* src, size_t srcSize)
{
// Cast each element to the target element type
TargetElementType* castValue = new TargetElementType[srcSize];
for (size_t i = 0; i < srcSize; ++i)
castValue[i] = (TargetElementType)src[i];
return castValue;
}
inline NDArrayViewPtr CloneAsDataType(const NDArrayViewPtr& source, DataType targetDataType, bool readOnly)
{
if (source->Device() != DeviceDescriptor::CPUDevice())
LogicError("CloneAsDataType currently does not support non-CPU source NDArrayView objects");
auto sourceDataType = source->GetDataType();
if (sourceDataType == targetDataType)
LogicError("CloneAsDataType: Source and target DataTypes are same");
if ((targetDataType != DataType::Float) && (targetDataType != DataType::Double))
LogicError("CloneAsDataType: Only Float and Double target DataTypes are supported");
NDArrayViewPtr newConstantValue;
auto sourceShape = source->Shape();
auto sourceSize = sourceShape.TotalSize();
if (sourceDataType == DataType::Float)
{
// Cast to double
double* castValue = Copy<float, double>(source->DataBuffer<float>(), sourceSize);
newConstantValue = MakeSharedObject<NDArrayView>(sourceShape, castValue, sourceSize, DeviceDescriptor::CPUDevice(), readOnly);
}
else
{
float* castValue = Copy<double, float>(source->DataBuffer<double>(), sourceSize);
newConstantValue = MakeSharedObject<NDArrayView>(sourceShape, castValue, sourceSize, DeviceDescriptor::CPUDevice(), readOnly);
}
return newConstantValue;
}
}
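A usage sketch (the source view is illustrative): widening a float CPU view to double yields a new CPU-resident view with its own buffer:

NDArrayViewPtr floatView = MakeSharedObject<NDArrayView>(1.0f, NDShape({ 2, 2 }), DeviceDescriptor::CPUDevice());
NDArrayViewPtr doubleView = CloneAsDataType(floatView, DataType::Double, /*readOnly =*/ true);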

View file

@ -59,12 +59,12 @@ inline void SaveAndReloadModel(CNTK::FunctionPtr& functionPtr, const std::vector
if ((_wunlink(s_tempModelPath.c_str()) != 0) && (errno != ENOENT))
std::runtime_error("Error deleting temp model file 'feedForward.net'");
std::unordered_map<std::wstring, CNTK::Variable*> inputVarNames;
std::unordered_map<std::wstring, CNTK::Variable*> inputVarUids;
std::unordered_map<std::wstring, CNTK::Variable*> outputVarNames;
for (auto varPtr : variables)
{
auto retVal = varPtr->IsOutput() ? outputVarNames.insert({ varPtr->Owner()->Name(), varPtr }) : inputVarNames.insert({ varPtr->Name(), varPtr });
auto retVal = varPtr->IsOutput() ? outputVarNames.insert({ varPtr->Owner()->Name(), varPtr }) : inputVarUids.insert({ varPtr->Uid(), varPtr });
if (!retVal.second)
std::runtime_error("SaveAndReloadModel: Multiple variables having same name cannot be restored after save and reload");
}
@ -76,10 +76,10 @@ inline void SaveAndReloadModel(CNTK::FunctionPtr& functionPtr, const std::vector
std::runtime_error("Error deleting temp model file 'feedForward.net'");
auto inputs = functionPtr->Inputs();
for (auto inputVarInfo : inputVarNames)
for (auto inputVarInfo : inputVarUids)
{
auto newInputVar = *(std::find_if(inputs.begin(), inputs.end(), [inputVarInfo](const CNTK::Variable& var) {
return (var.Name() == inputVarInfo.first);
return (var.Uid() == inputVarInfo.first);
}));
*(inputVarInfo.second) = newInputVar;
@ -196,43 +196,6 @@ std::pair<CNTK::FunctionPtr, CNTK::FunctionPtr> LSTMPComponentWithSelfStabilizat
return { LSTMCell.first, LSTMCell.second };
}
inline CNTK::MinibatchSourcePtr CreateTextMinibatchSource(const std::wstring& filePath,
size_t featureDim,
size_t labelDim,
size_t epochSize,
bool isFeatureSparse = false,
bool isLabelSparse = false,
const std::wstring& featureAlias = L"",
const std::wstring& labelAlias = L"")
{
CNTK::Dictionary featuresStreamConfig;
featuresStreamConfig[L"dim"] = featureDim;
featuresStreamConfig[L"format"] = isFeatureSparse ? L"sparse" : L"dense";
if (!featureAlias.empty())
featuresStreamConfig[L"alias"] = featureAlias;
CNTK::Dictionary labelsStreamConfig;
labelsStreamConfig[L"dim"] = labelDim;
labelsStreamConfig[L"format"] = isLabelSparse ? L"sparse" : L"dense";
if (!labelAlias.empty())
labelsStreamConfig[L"alias"] = labelAlias;
CNTK::Dictionary inputStreamsConfig;
inputStreamsConfig[L"features"] = featuresStreamConfig;
inputStreamsConfig[L"labels"] = labelsStreamConfig;
CNTK::Dictionary deserializerConfiguration;
deserializerConfiguration[L"type"] = L"CNTKTextFormatDeserializer";
deserializerConfiguration[L"file"] = filePath;
deserializerConfiguration[L"input"] = inputStreamsConfig;
CNTK::Dictionary minibatchSourceConfiguration;
minibatchSourceConfiguration[L"epochSize"] = epochSize;
minibatchSourceConfiguration[L"deserializers"] = std::vector<CNTK::DictionaryValue>({ deserializerConfiguration });
return CreateCompositeMinibatchSource(minibatchSourceConfiguration);
}
inline std::vector<size_t> GenerateSequenceLengths(size_t numSequences, size_t maxAllowedSequenceLength)
{
std::vector<size_t> sequenceLengths(numSequences);

View file

@ -6,38 +6,6 @@ using namespace CNTK;
using namespace std::placeholders;
inline CNTK::MinibatchSourcePtr CreateSeq2SeqMinibatchSource(const std::wstring& filePath, size_t inputVocabSize, size_t labelsVocabSize)
{
CNTK::Dictionary inputStreamConfig;
inputStreamConfig[L"dim"] = inputVocabSize;
inputStreamConfig[L"format"] = L"sparse";
inputStreamConfig[L"alias"] = L"S0";
CNTK::Dictionary labelsStreamConfig;
labelsStreamConfig[L"dim"] = labelsVocabSize;
labelsStreamConfig[L"format"] = L"sparse";
labelsStreamConfig[L"alias"] = L"S1";
CNTK::Dictionary inputStreamsConfig;
inputStreamsConfig[L"rawInput"] = inputStreamConfig;
inputStreamsConfig[L"rawLabels"] = labelsStreamConfig;
CNTK::Dictionary deserializerConfiguration;
deserializerConfiguration[L"type"] = L"CNTKTextFormatDeserializer";
deserializerConfiguration[L"file"] = filePath;
deserializerConfiguration[L"input"] = inputStreamsConfig;
deserializerConfiguration[L"skipSequenceIds"] = L"false";
deserializerConfiguration[L"maxErrors"] = (size_t)100;
deserializerConfiguration[L"traceLevel"] = (size_t)1;
deserializerConfiguration[L"chunkSizeInBytes"] = (size_t)30000000;
CNTK::Dictionary minibatchSourceConfiguration;
minibatchSourceConfiguration[L"epochSize"] = (size_t)2000;
minibatchSourceConfiguration[L"deserializers"] = std::vector<CNTK::DictionaryValue>({ deserializerConfiguration });
return CreateCompositeMinibatchSource(minibatchSourceConfiguration);
}
void TrainSequenceToSequenceTranslator(const DeviceDescriptor& device, bool useSparseInputs, bool testSaveAndReLoad, bool testCheckpointing)
{
using namespace std::placeholders;
@ -150,9 +118,14 @@ void TrainSequenceToSequenceTranslator(const DeviceDescriptor& device, bool useS
errs = errsVar;
}
auto minibatchSource = CreateSeq2SeqMinibatchSource(L"cmudict-0.7b.train-dev-20-21.ctf", inputVocabDim, labelVocabDim);
auto rawInputStreamInfo = minibatchSource->StreamInfo(L"rawInput");
auto rawLabelsStreamInfo = minibatchSource->StreamInfo(L"rawLabels");
auto featureStreamName = L"rawInput";
auto labelStreamName = L"rawLabels";
auto minibatchSource = TextFormatMinibatchSource(L"cmudict-0.7b.train-dev-20-21.ctf",
{ { featureStreamName, inputVocabDim, true, L"S0" }, {labelStreamName, labelVocabDim, true, L"S1" } },
5000);
auto rawInputStreamInfo = minibatchSource->StreamInfo(featureStreamName);
auto rawLabelsStreamInfo = minibatchSource->StreamInfo(labelStreamName);
double learningRatePerSample = 0.007;
size_t momentumTimeConstant = 1100;
@ -164,7 +137,7 @@ void TrainSequenceToSequenceTranslator(const DeviceDescriptor& device, bool useS
size_t outputFrequencyInMinibatches = 1;
size_t minibatchSize = 72;
size_t numMinibatchesToCheckpointAfter = testCheckpointing ? 3 : SIZE_MAX;
size_t numMinibatchesToRestoreFromCheckpointAfter = testCheckpointing ? 6 : SIZE_MAX;
size_t numMinibatchesToRestoreFromCheckpointAfter = testCheckpointing ? 20 : SIZE_MAX;
bool restorationDone = false;
const wchar_t* modelFile = L"seq2seq.model";
for (size_t i = 0; true; i++)
@ -172,25 +145,7 @@ void TrainSequenceToSequenceTranslator(const DeviceDescriptor& device, bool useS
if (!restorationDone && (i == numMinibatchesToRestoreFromCheckpointAfter))
{
printf("Trainer restoring from checkpoint at path %S\n", modelFile);
auto inputs = trainer.LossFunction()->Inputs();
auto findInputVariableIndex = [&inputs](const Variable& inputVar) {
for (size_t i = 0; i < inputs.size(); ++i)
{
if (inputs[i] == inputVar)
return i;
}
LogicError("Specified variable is not an input of the loss function");
};
size_t rawInputIndex = findInputVariableIndex(rawInput);
size_t rawLabelsIndex = findInputVariableIndex(rawLabels);
trainer.RestoreFromCheckpoint(modelFile);
rawInput = trainer.LossFunction()->Inputs()[rawInputIndex];
rawLabels = trainer.LossFunction()->Inputs()[rawLabelsIndex];
i = numMinibatchesToCheckpointAfter;
restorationDone = true;
}
@ -213,8 +168,6 @@ void TrainSequenceToSequenceTranslator(const DeviceDescriptor& device, bool useS
void TrainSequenceToSequenceTranslator()
{
// TODO: Also test with sparse input variables in the graph
// TODO: Also test trainer checkpointing
TrainSequenceToSequenceTranslator(DeviceDescriptor::GPUDevice(0), false, true, false);
TrainSequenceToSequenceTranslator(DeviceDescriptor::CPUDevice(), false, false, false);
TrainSequenceToSequenceTranslator(DeviceDescriptor::GPUDevice(0), false, false, true);
TrainSequenceToSequenceTranslator(DeviceDescriptor::CPUDevice(), false, true, false);
}

View file

@ -53,7 +53,7 @@ void TrainLSTMSequenceClassifer(const DeviceDescriptor& device, bool testSaveAnd
prediction = predictionVar;
}
auto minibatchSource = CreateTextMinibatchSource(L"Train.ctf", inputDim, numOutputClasses, 0, true, false, L"x", L"y");
auto minibatchSource = TextFormatMinibatchSource(L"Train.ctf", { { L"features", inputDim, true, L"x" }, { L"labels", numOutputClasses, false, L"y" } }, 0);
const size_t minibatchSize = 200;
auto featureStreamInfo = minibatchSource->StreamInfo(features);

View file

@ -18,7 +18,7 @@ void TrainSimpleFeedForwardClassifer(const DeviceDescriptor& device)
const size_t numSweepsToTrainWith = 2;
const size_t numMinibatchesToTrain = (numSamplesPerSweep * numSweepsToTrainWith) / minibatchSize;
auto minibatchSource = CreateTextMinibatchSource(L"SimpleDataTrain_cntk_text.txt", (size_t)2, (size_t)2, 0);
auto minibatchSource = TextFormatMinibatchSource(L"SimpleDataTrain_cntk_text.txt", { { L"features", inputDim }, { L"labels", numOutputClasses} }, 0);
auto streamInfos = minibatchSource->StreamInfos();
auto featureStreamInfo = std::find_if(streamInfos.begin(), streamInfos.end(), [](const StreamInformation& streamInfo) { return (streamInfo.m_name == L"features"); });
auto labelStreamInfo = std::find_if(streamInfos.begin(), streamInfos.end(), [](const StreamInformation& streamInfo) { return (streamInfo.m_name == L"labels"); });
@ -55,7 +55,7 @@ void TrainSimpleFeedForwardClassifer(const DeviceDescriptor& device)
}
double learningRatePerSample = 0.02;
minibatchSource = CreateTextMinibatchSource(L"SimpleDataTrain_cntk_text.txt", (size_t)2, (size_t)2, SIZE_MAX);
minibatchSource = TextFormatMinibatchSource(L"SimpleDataTrain_cntk_text.txt", { { L"features", inputDim }, { L"labels", numOutputClasses } });
Trainer trainer(classifierOutput, trainingLoss, prediction, { SGDLearner(classifierOutput->Parameters(), learningRatePerSample) });
size_t outputFrequencyInMinibatches = 20;
for (size_t i = 0; i < numMinibatchesToTrain; ++i)
@ -101,11 +101,12 @@ void TrainMNISTClassifier(const DeviceDescriptor& device)
const size_t numSweepsToTrainWith = 3;
const size_t numMinibatchesToTrain = (numSamplesPerSweep * numSweepsToTrainWith) / minibatchSize;
auto minibatchSource = CreateTextMinibatchSource(L"Train-28x28_cntk_text.txt", (size_t)784, (size_t)10, SIZE_MAX);
auto featureStreamName = L"features";
auto labelsStreamName = L"labels";
auto minibatchSource = TextFormatMinibatchSource(L"Train-28x28_cntk_text.txt", { { featureStreamName, inputDim }, { labelsStreamName, numOutputClasses } });
auto streamInfos = minibatchSource->StreamInfos();
auto featureStreamInfo = std::find_if(streamInfos.begin(), streamInfos.end(), [](const StreamInformation& streamInfo) { return (streamInfo.m_name == L"features"); });
auto labelStreamInfo = std::find_if(streamInfos.begin(), streamInfos.end(), [](const StreamInformation& streamInfo) { return (streamInfo.m_name == L"labels"); });
auto featureStreamInfo = minibatchSource->StreamInfo(featureStreamName);
auto labelStreamInfo = minibatchSource->StreamInfo(labelsStreamName);
double learningRatePerSample = 0.003125;
Trainer trainer(classifierOutput, trainingLoss, prediction, { SGDLearner(classifierOutput->Parameters(), learningRatePerSample) });
@ -114,7 +115,7 @@ void TrainMNISTClassifier(const DeviceDescriptor& device)
for (size_t i = 0; i < numMinibatchesToTrain; ++i)
{
auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
trainer.TrainMinibatch({ { input, minibatchData[*featureStreamInfo].m_data }, { labels, minibatchData[*labelStreamInfo].m_data } }, device);
trainer.TrainMinibatch({ { input, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
PrintTrainingProgress(trainer, i, outputFrequencyInMinibatches);
}
}