Add recurrence and variable length sequence support for the 'Function' abstraction
This commit is contained in:
Parent
76bf193267
Commit
933adb275a
3  CNTK.sln
|
@ -1287,6 +1287,9 @@ Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CNTKv2LibraryDll", "Source\
|
|||
EndProjectSection
|
||||
EndProject
|
||||
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "LibraryTests", "Source\CNTKv2LibraryDll\LibraryTests\LibraryTests.vcxproj", "{F4CC3AB2-0DB2-4281-929A-2E68E30F0F6E}"
|
||||
ProjectSection(ProjectDependencies) = postProject
|
||||
{E5606ECE-48CA-4464-BB12-09D81D02B9EF} = {E5606ECE-48CA-4464-BB12-09D81D02B9EF}
|
||||
EndProjectSection
|
||||
EndProject
|
||||
Global
|
||||
GlobalSection(SolutionConfigurationPlatforms) = preSolution
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#include <assert.h>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <string>
|
||||
|
||||
namespace CNTK
|
||||
{
|
||||
|
@ -24,6 +25,7 @@ namespace CNTK
|
|||
///
|
||||
enum class DataType
|
||||
{
|
||||
Unknown,
|
||||
Float,
|
||||
Double,
|
||||
|
||||
|
@ -413,13 +415,13 @@ namespace CNTK
|
|||
/// Static method to construct a new NDArrayView object whose contents are drawn from a normal distribution with the specified mean and standard deviation.
|
||||
///
|
||||
template <typename ElementType>
|
||||
static NDArrayViewPtr RandomNormal(const NDShape& shape, double mean, double stdDev, unsigned long seed, const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice());
|
||||
static NDArrayViewPtr RandomNormal(const NDShape& shape, double mean, double stdDev, unsigned long seed = 1, const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice());
|
||||
|
||||
///
|
||||
/// Static method to construct a new NDArrayView object whose contents are drawn from a uniform distribution in the specified value range.
|
||||
///
|
||||
template <typename ElementType>
|
||||
static NDArrayViewPtr RandomUniform(const NDShape& shape, double rangeStart, double rangeEnd, unsigned long seed, const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice());
|
||||
static NDArrayViewPtr RandomUniform(const NDShape& shape, double rangeStart, double rangeEnd, unsigned long seed = 1, const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice());
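Because the seed now defaults to 1, callers can drop both trailing arguments. A minimal illustrative call (not code from this commit):

    auto weights = CNTK::NDArrayView::RandomUniform<float>({ 25, 50 }, -0.05, 0.05); // seed defaults to 1, device to DeviceDescriptor::DefaultDevice()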
|
||||
|
||||
private:
|
||||
// Disallow copy construction and assignment
|
||||
|
@ -430,17 +432,20 @@ namespace CNTK
|
|||
NDArrayView& operator=(NDArrayView&&) = delete;
|
||||
NDArrayView(NDArrayView&& other) = delete;
|
||||
|
||||
private:
|
||||
static const size_t AutoSelectRowColSplitPoint = SIZE_MAX;
|
||||
|
||||
private:
|
||||
NDArrayView(CNTK::DataType dataType, const DeviceDescriptor& device, CNTK::StorageFormat storageType, const NDShape& viewShape, bool readOnly, void* tensorView);
|
||||
|
||||
template <typename ElementType>
|
||||
std::shared_ptr<Microsoft::MSR::CNTK::Matrix<ElementType>> GetMatrixImpl() const;
|
||||
std::shared_ptr<Microsoft::MSR::CNTK::Matrix<ElementType>> GetMatrixImpl(size_t rowColSplitPoint) const;
|
||||
|
||||
template <typename ElementType>
|
||||
std::shared_ptr<const Microsoft::MSR::CNTK::Matrix<ElementType>> GetMatrix() const;
|
||||
std::shared_ptr<const Microsoft::MSR::CNTK::Matrix<ElementType>> GetMatrix(size_t rowColSplitPoint = AutoSelectRowColSplitPoint) const;
|
||||
|
||||
template <typename ElementType>
|
||||
std::shared_ptr<Microsoft::MSR::CNTK::Matrix<ElementType>> GetWritableMatrix();
|
||||
std::shared_ptr<Microsoft::MSR::CNTK::Matrix<ElementType>> GetWritableMatrix(size_t rowColSplitPoint = AutoSelectRowColSplitPoint);
|
||||
|
||||
template <typename ElementType>
|
||||
const Microsoft::MSR::CNTK::TensorView<ElementType>* GetTensorView() const;
|
||||
|
@ -461,9 +466,91 @@ namespace CNTK
|
|||
void* m_tensorView;
|
||||
};
|
||||
|
||||
///
|
||||
/// Denotes a multi-dimensional mask used for specifying specific sections of a NDArrayView object as masked/invalid.
|
||||
/// This type denotes a view and there may be multiple simultaneous views of the data underlying a NDMask instance.
|
||||
///
|
||||
class CNTK_API NDMask final : public _Internal::_ReferenceCounter
|
||||
{
|
||||
friend class CompositeFunction;
|
||||
|
||||
public:
|
||||
///
|
||||
/// Construct a new Mask object of specified shape
|
||||
///
|
||||
explicit NDMask(const NDShape& shape, const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice());
|
||||
|
||||
///
|
||||
/// Destruct 'this' mask object
|
||||
///
|
||||
~NDMask();
|
||||
|
||||
///
|
||||
/// Mask out the specified sub-section of 'this' mask
|
||||
///
|
||||
void MaskSection(const std::vector<size_t>& sectionOffset, const NDShape& sectionShape);
|
||||
|
||||
///
|
||||
/// Clear the mask; i.e. unmask all currently masked values
|
||||
///
|
||||
void Clear();
|
||||
|
||||
///
|
||||
/// Returns the descriptor of the device that 'this' mask resides on
|
||||
///
|
||||
DeviceDescriptor Device() const
|
||||
{
|
||||
return m_device;
|
||||
}
|
||||
|
||||
///
|
||||
/// Returns the shape of 'this' mask.
|
||||
///
|
||||
NDShape Shape() const
|
||||
{
|
||||
return m_maskShape;
|
||||
}
|
||||
|
||||
///
|
||||
/// Creates a new NDMask with newly allocated storage on the same device as 'this' mask and copies 'this' mask's contents into the newly allocated mask.
|
||||
///
|
||||
NDMaskPtr DeepClone() const;
|
||||
|
||||
///
|
||||
/// Creates a new NDMask which is an alias of 'this' mask.
|
||||
///
|
||||
NDMaskPtr Alias() const;
|
||||
|
||||
///
|
||||
/// Copies the contents of the 'source' NDMask to 'this' mask.
|
||||
/// The shapes of the 'source' mask and 'this' mask must be identical.
|
||||
///
|
||||
void CopyFrom(const NDMask& source);
|
||||
|
||||
private:
|
||||
NDMask(const NDShape& shape, Microsoft::MSR::CNTK::Matrix<char>* matrix);
|
||||
Microsoft::MSR::CNTK::Matrix<char>* GetMatrix() const;
|
||||
|
||||
// Disallow copy construction and assignment
|
||||
NDMask(const NDMask&) = delete;
|
||||
NDMask& operator=(const NDMask&) = delete;
|
||||
|
||||
// Disallow move construction and assignment
|
||||
NDMask& operator=(NDMask&&) = delete;
|
||||
NDMask(NDMask&& other) = delete;
|
||||
|
||||
private:
|
||||
DeviceDescriptor m_device;
|
||||
NDShape m_maskShape;
|
||||
|
||||
Microsoft::MSR::CNTK::Matrix<char>* m_matrixView;
|
||||
};
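The NDMask declared above is the building block for the new variable-length sequence support. A rough usage sketch (illustrative shapes and offsets, not code from this commit): lay the mask out as { maxNumTimeSteps, numSequences } and invalidate the trailing steps of a short sequence.

    CNTK::NDMask mask({ 5, 2 }, CNTK::DeviceDescriptor::CPUDevice()); // 5 time steps, 2 sequences
    mask.MaskSection({ 3, 1 }, { CNTK::NDShape::InferredDimension, 1 }); // sequence 1 has only 3 valid steps
    mask.Clear(); // unmask everything again if the mask is to be reused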
|
||||
|
||||
///
|
||||
/// Denotes a multi-dimensional array with an optional mask and is
|
||||
/// the actual data fed into or produced from a computation.
|
||||
/// Denotes a multi-dimensional array with an optional mask and is the actual data fed into or produced from a computation.
|
||||
/// The mask is typically of lower dimensionality than the data, meaning data is masked in coarse individual sample units where
|
||||
/// sample shape is data.Shape().SubShape(0, data.Shape().NumAxes() - mask.Shape().NumAxes())
|
||||
/// Also, note that the size of the data's trailing mask.Shape().NumAxes() dimensions must match the mask shape dimensions.
|
||||
///
|
||||
class CNTK_API Value : public _Internal::_ReferenceCounter
|
||||
{
|
||||
|
@ -473,6 +560,18 @@ namespace CNTK
|
|||
///
|
||||
Value(const NDArrayViewPtr& data);
|
||||
|
||||
///
|
||||
/// A multi-dimensional value with an associated mask.
|
||||
///
|
||||
Value(const NDArrayViewPtr& data, const NDMaskPtr& mask);
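A hedged sketch of assembling such a masked Value by hand, reusing the smart-pointer construction pattern that appears elsewhere in this commit (shapes are illustrative):

    using namespace CNTK;
    auto data = NDArrayView::RandomUniform<float>({ 2, 3, 2 }, -1.0, 1.0, 1); // 2-dim samples, 3 time steps, 2 sequences
    NDMaskPtr mask(new NDMask({ 3, 2 }), [](_Internal::_ReferenceCounter* ptr) { delete ptr; });
    mask->MaskSection({ 1, 1 }, { NDShape::InferredDimension, 1 }); // second sequence has only 1 valid step
    ValuePtr value(new Value(data, mask), [](_Internal::_ReferenceCounter* ptr) { delete ptr; });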
|
||||
|
||||
///
|
||||
/// Create a new Value object containing a collection of variable length sequences.
|
||||
/// The created Value object contains a copy of the specified 'sequences' data.
|
||||
///
|
||||
template <typename ElementType>
|
||||
static ValuePtr Create(const NDShape& sampleShape, const std::vector<const std::vector<ElementType>>& sequences, const DeviceDescriptor& device, bool readOnly = false);
|
||||
|
||||
///
|
||||
/// Destruct 'this' Value object.
|
||||
///
|
||||
|
@ -483,6 +582,27 @@ namespace CNTK
|
|||
///
|
||||
virtual NDArrayViewPtr Data() const;
|
||||
|
||||
///
|
||||
/// Returns the NDMask object corresponding to the mask associated with 'this' Value object.
|
||||
///
|
||||
virtual NDMaskPtr Mask() const;
|
||||
|
||||
///
|
||||
/// Creates a new Value with newly allocated storage on the same device as 'this' Value and copies 'this' Value's contents into the newly allocated Value.
|
||||
///
|
||||
virtual ValuePtr DeepClone(bool readOnly = false) const;
|
||||
|
||||
///
|
||||
/// Creates a new Value which is an alias of 'this' Value.
|
||||
///
|
||||
virtual ValuePtr Alias(bool readOnly = false) const;
|
||||
|
||||
///
|
||||
/// Copies the contents of the 'source' Value to 'this' Value.
|
||||
/// The shapes of the 'source' Value's data and mask must be identical to 'this' Value's data and mask.
|
||||
///
|
||||
virtual void CopyFrom(const Value& source);
|
||||
|
||||
private:
|
||||
// Disallow copy construction and assignment
|
||||
Value(const Value&) = delete;
|
||||
|
@ -494,8 +614,170 @@ namespace CNTK
|
|||
|
||||
private:
|
||||
NDArrayViewPtr m_data;
|
||||
NDMaskPtr m_mask;
|
||||
};
|
||||
|
||||
///
|
||||
/// Denotes an Axis of a Variable and is used for specifying the axes parameters of certain Functions such as reductions.
|
||||
/// Besides the static axes corresponding to each of the axes of the Variable's shape, Input and Output Variables
|
||||
/// also have one or more dynamic axes (corresponding to the sequence dimensions) and one implicit batch axis denoting the axes
|
||||
/// along which multiple sequences are batched in the Values corresponding to the variable when performing computations.
|
||||
///
|
||||
class Axis final
|
||||
{
|
||||
public:
|
||||
///
|
||||
/// Construct an Axis object denoting a static axis with the specified index.
|
||||
///
|
||||
Axis(size_t staticAxisIdx)
|
||||
: m_staticAxisIdx(staticAxisIdx)
|
||||
{
|
||||
const wchar_t* staticAxisNamePrefix = L"staticAxis_";
|
||||
std::wstring tempName = staticAxisNamePrefix;
|
||||
tempName = tempName + std::to_wstring(staticAxisIdx);
|
||||
m_name = CopyString(tempName.c_str());
|
||||
}
|
||||
|
||||
///
|
||||
/// Construct a dynamic axis with the specified name.
|
||||
///
|
||||
Axis(const std::wstring& name)
|
||||
: m_staticAxisIdx(SIZE_MAX)
|
||||
{
|
||||
m_name = CopyString(name.c_str());
|
||||
}
|
||||
|
||||
///
|
||||
/// Copy constructor.
|
||||
///
|
||||
Axis(const Axis& other)
|
||||
: m_staticAxisIdx(SIZE_MAX), m_name(nullptr)
|
||||
{
|
||||
*this = other;
|
||||
}
|
||||
|
||||
///
|
||||
/// Copy assignment.
|
||||
///
|
||||
Axis& operator=(const Axis& other)
|
||||
{
|
||||
if (this != &other)
|
||||
{
|
||||
delete[] m_name;
|
||||
|
||||
m_staticAxisIdx = other.m_staticAxisIdx;
|
||||
m_name = (other.m_name != nullptr) ? CopyString(other.m_name) : other.m_name;
|
||||
}
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
///
|
||||
/// Move constructor.
|
||||
///
|
||||
Axis(Axis&& other)
|
||||
: m_staticAxisIdx(SIZE_MAX), m_name(nullptr)
|
||||
{
|
||||
*this = std::move(other);
|
||||
}
|
||||
|
||||
///
|
||||
/// Move assignment.
|
||||
///
|
||||
Axis& operator=(Axis&& other)
|
||||
{
|
||||
assert(this != &other);
|
||||
|
||||
delete[] m_name;
|
||||
|
||||
m_staticAxisIdx = other.m_staticAxisIdx;
|
||||
m_name = other.m_name;
|
||||
|
||||
other.m_staticAxisIdx = SIZE_MAX;
|
||||
other.m_name = nullptr;
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
///
|
||||
/// Returns a boolean indicating if 'this' Axis corresponds to a static axis
|
||||
///
|
||||
bool IsStaticAxis() const
|
||||
{
|
||||
return m_staticAxisIdx != SIZE_MAX;
|
||||
}
|
||||
|
||||
///
|
||||
/// Returns the axis index if 'this' Axis is a static axis. Throws an exception otherwise.
|
||||
///
|
||||
size_t StaticAxisIndex() const
|
||||
{
|
||||
if (!IsStaticAxis())
|
||||
InvalidArgument("Cannot query the static axis index for a non-static axis");
|
||||
|
||||
return m_staticAxisIdx;
|
||||
}
|
||||
|
||||
///
|
||||
/// Static Axis object representing the default dynamic axis.
|
||||
///
|
||||
static Axis DefaultDynamicAxis;
|
||||
|
||||
///
|
||||
/// Static Axis object representing the batch axis.
|
||||
///
|
||||
static Axis BatchAxis;
|
||||
|
||||
///
|
||||
/// Special Axis object denoting all the axes of the Value object in whose context it is used.
|
||||
///
|
||||
static Axis AllAxes;
|
||||
|
||||
///
|
||||
/// Name of 'this' axis
|
||||
///
|
||||
std::wstring Name() const
|
||||
{
|
||||
return m_name;
|
||||
}
|
||||
|
||||
///
|
||||
/// Destructor
|
||||
///
|
||||
~Axis()
|
||||
{
|
||||
delete[] m_name;
|
||||
}
|
||||
|
||||
///
|
||||
/// Default constructor; results in an invalid axis object.
|
||||
///
|
||||
Axis()
|
||||
: m_staticAxisIdx(SIZE_MAX), m_name(nullptr)
|
||||
{
|
||||
}
|
||||
|
||||
private:
|
||||
size_t m_staticAxisIdx;
|
||||
wchar_t* m_name;
|
||||
};
|
||||
|
||||
inline bool operator==(const Axis& first, const Axis& second)
|
||||
{
|
||||
if (first.IsStaticAxis() != second.IsStaticAxis())
|
||||
return false;
|
||||
|
||||
if (first.IsStaticAxis())
|
||||
return first.StaticAxisIndex() == second.StaticAxisIndex();
|
||||
else
|
||||
return first.Name() == second.Name();
|
||||
}
|
||||
|
||||
inline bool operator!=(const Axis& first, const Axis& second)
|
||||
{
|
||||
return !(first == second);
|
||||
}
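A short illustration (not from the commit) of how the two kinds of Axis behave under these comparisons:

    CNTK::Axis staticAxis0(0);                                      // first static axis of a Variable's shape
    CNTK::Axis seqAxis = CNTK::Axis::DefaultDynamicAxis;            // the built-in default dynamic (sequence) axis
    bool mixedKinds = (staticAxis0 == seqAxis);                     // false: one is static, the other dynamic
    bool sameName = (seqAxis == CNTK::Axis(L"defaultDynamicAxis")); // true: dynamic axes compare by name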
|
||||
|
||||
///
|
||||
/// Enumeration type denoting the kind of a symbolic Variable object
|
||||
///
|
||||
|
@ -504,7 +786,8 @@ namespace CNTK
|
|||
Input,
|
||||
Output,
|
||||
Parameter,
|
||||
Constant
|
||||
Constant,
|
||||
Placeholder
|
||||
};
|
||||
|
||||
///
|
||||
|
@ -524,7 +807,7 @@ namespace CNTK
|
|||
/// Create an 'Input' Variable.
|
||||
///
|
||||
Variable(const NDShape& shape, CNTK::DataType dataType, const std::wstring& name = L"")
|
||||
: Variable(shape, VariableKind::Input, dataType, nullptr, nullptr, false, name)
|
||||
: Variable(shape, VariableKind::Input, dataType, nullptr, nullptr, false, { Axis::DefaultDynamicAxis }, name)
|
||||
{
|
||||
}
|
||||
|
||||
|
@ -532,15 +815,15 @@ namespace CNTK
|
|||
/// Create an 'Input' Variable and specify if gradients are to be computed for this input
|
||||
///
|
||||
Variable(const NDShape& shape, CNTK::DataType dataType, bool needsGradient, const std::wstring& name = L"")
|
||||
: Variable(shape, VariableKind::Input, dataType, nullptr, nullptr, needsGradient, name)
|
||||
: Variable(shape, VariableKind::Input, dataType, nullptr, nullptr, needsGradient, { Axis::DefaultDynamicAxis }, name)
|
||||
{
|
||||
}
|
||||
|
||||
///
|
||||
/// Create an 'Output' variable
|
||||
///
|
||||
Variable(const NDShape& shape, CNTK::DataType dataType, Function* ownerFunction, const std::wstring& name = L"")
|
||||
: Variable(shape, VariableKind::Output, dataType, ownerFunction, nullptr, false, name)
|
||||
Variable(const NDShape& shape, CNTK::DataType dataType, Function* ownerFunction, const std::vector<Axis>& dynamicAxes, const std::wstring& name = L"")
|
||||
: Variable(shape, VariableKind::Output, dataType, ownerFunction, nullptr, false, dynamicAxes, name)
|
||||
{
|
||||
}
|
||||
|
||||
|
@ -558,6 +841,14 @@ namespace CNTK
|
|||
return m_dataFields->m_shape;
|
||||
}
|
||||
|
||||
///
|
||||
/// Returns the dynamic axes of 'this' variable
|
||||
///
|
||||
std::vector<Axis> DynamicAxes() const
|
||||
{
|
||||
return m_dataFields->m_dynamicAxes;
|
||||
}
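Every 'Input' Variable created through the constructors above now implicitly carries the default dynamic (sequence) axis; an illustrative check:

    CNTK::Variable features({ 2 }, CNTK::DataType::Float, L"features");
    auto axes = features.DynamicAxes();
    bool hasSequenceAxis = (axes.size() == 1) && (axes[0] == CNTK::Axis::DefaultDynamicAxis); // true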
|
||||
|
||||
///
|
||||
/// Returns the VariableKind of 'this' variable
|
||||
///
|
||||
|
@ -567,7 +858,7 @@ namespace CNTK
|
|||
}
|
||||
|
||||
///
|
||||
/// Returns a boolean value indicating if 'this' variable is a parameter
|
||||
/// Returns a boolean value indicating if 'this' variable is a Parameter
|
||||
///
|
||||
bool IsParameter() const
|
||||
{
|
||||
|
@ -575,13 +866,21 @@ namespace CNTK
|
|||
}
|
||||
|
||||
///
|
||||
/// Returns a boolean value indicating if 'this' variable is a constant
|
||||
/// Returns a boolean value indicating if 'this' variable is a Constant
|
||||
///
|
||||
bool IsConstant() const
|
||||
{
|
||||
return Kind() == VariableKind::Constant;
|
||||
}
|
||||
|
||||
///
|
||||
/// Returns a boolean value indicating if 'this' variable is a Placeholder
|
||||
///
|
||||
bool IsPlaceholder() const
|
||||
{
|
||||
return Kind() == VariableKind::Placeholder;
|
||||
}
|
||||
|
||||
///
|
||||
/// Returns the name of 'this' variable
|
||||
///
|
||||
|
@ -624,9 +923,8 @@ namespace CNTK
|
|||
}
|
||||
|
||||
protected:
|
||||
// Create an 'Input' Variable
|
||||
Variable(const NDShape& shape, VariableKind varType, CNTK::DataType dataType, const NDArrayViewPtr& value, bool needsGradient, const std::wstring& name)
|
||||
: Variable(shape, varType, dataType, nullptr, value, needsGradient, name)
|
||||
Variable(const NDShape& shape, VariableKind varType, CNTK::DataType dataType, const NDArrayViewPtr& value, bool needsGradient, const std::vector<Axis>& dynamicAxes, const std::wstring& name)
|
||||
: Variable(shape, varType, dataType, nullptr, value, needsGradient, dynamicAxes, name)
|
||||
{
|
||||
}
|
||||
|
||||
|
@ -637,14 +935,14 @@ namespace CNTK
|
|||
}
|
||||
|
||||
private:
|
||||
Variable(const NDShape& shape, VariableKind varType, CNTK::DataType dataType, Function* ownerFunction, const NDArrayViewPtr& value, bool needsGradient, const std::wstring& name)
|
||||
: m_dataFields(new _VariableFields(shape, varType, dataType, ownerFunction, value, needsGradient, (name == L"") ? nullptr : name.c_str()), [](_Internal::_ReferenceCounter* ptr) { delete ptr; })
|
||||
Variable(const NDShape& shape, VariableKind varType, CNTK::DataType dataType, Function* ownerFunction, const NDArrayViewPtr& value, bool needsGradient, const std::vector<Axis>& dynamicAxes, const std::wstring& name)
|
||||
: m_dataFields(new _VariableFields(shape, varType, dataType, ownerFunction, value, needsGradient, dynamicAxes, (name == L"") ? nullptr : name.c_str()), [](_Internal::_ReferenceCounter* ptr) { delete ptr; })
|
||||
{
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
struct _VariableFields : public _Internal::_ReferenceCounter
|
||||
struct _VariableFields final : public _Internal::_ReferenceCounter
|
||||
{
|
||||
NDShape m_shape;
|
||||
VariableKind m_varKind;
|
||||
|
@ -653,9 +951,10 @@ namespace CNTK
|
|||
NDArrayViewPtr m_value;
|
||||
bool m_needsGradient;
|
||||
wchar_t* m_name;
|
||||
_Internal::_SimpleVector<Axis> m_dynamicAxes;
|
||||
|
||||
_VariableFields(const NDShape& shape, VariableKind varType, CNTK::DataType type, Function* ownerFunction, const NDArrayViewPtr& value, bool needsGradient, const wchar_t* name)
|
||||
: m_shape(shape), m_varKind(varType), m_dataType(type), m_ownerFunction(ownerFunction), m_value(value), m_needsGradient(needsGradient), m_name(nullptr)
|
||||
_VariableFields(const NDShape& shape, VariableKind varType, CNTK::DataType type, Function* ownerFunction, const NDArrayViewPtr& value, bool needsGradient, const std::vector<Axis>& dynamicAxes, const wchar_t* name)
|
||||
: m_shape(shape), m_varKind(varType), m_dataType(type), m_ownerFunction(ownerFunction), m_value(value), m_needsGradient(needsGradient), m_dynamicAxes(_Internal::_SimpleVector<Axis>::CreateSimpleVector(dynamicAxes)), m_name(nullptr)
|
||||
{
|
||||
if (name != nullptr)
|
||||
m_name = CopyString(name);
|
||||
|
@ -665,6 +964,15 @@ namespace CNTK
|
|||
{
|
||||
delete[] m_name;
|
||||
}
|
||||
|
||||
private:
|
||||
// Disallow copy construction and assignment
|
||||
_VariableFields(const _VariableFields&) = delete;
|
||||
_VariableFields& operator=(const _VariableFields& other) = delete;
|
||||
|
||||
// Disallow move construction and assignment
|
||||
_VariableFields(_VariableFields&&) = delete;
|
||||
_VariableFields& operator=(_VariableFields&&) = delete;
|
||||
};
|
||||
typedef _Internal::_ReferenceCounterSharedPtr<_VariableFields> _VariableFieldsPtr;
|
||||
|
||||
|
@ -689,7 +997,7 @@ namespace CNTK
|
|||
/// Construct a parameter whose initial contents are a copy of the specified 'value'
|
||||
///
|
||||
explicit Parameter(const NDArrayViewPtr& value, const std::wstring& name = L"")
|
||||
: Variable(value->Shape(), VariableKind::Parameter, value->DataType(), value->DeepClone(), true, name)
|
||||
: Variable(value->Shape(), VariableKind::Parameter, value->DataType(), value->DeepClone(), true, {}, name)
|
||||
{
|
||||
}
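A small usage sketch (illustrative, not from the commit) combining this constructor with the NDArrayView random initializers declared earlier:

    CNTK::Parameter W(CNTK::NDArrayView::RandomUniform<float>({ 25, 50 }, -0.05, 0.05), L"W");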
|
||||
|
||||
|
@ -700,7 +1008,7 @@ namespace CNTK
|
|||
///
|
||||
template<typename ElemType>
|
||||
Parameter(const NDShape& shape, ElemType initValue, const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice(), const std::wstring& name = L"")
|
||||
: Variable(shape, VariableKind::Parameter, GetDataType<ElemType>(), new NDArrayView(initValue, shape, device), true, name)
|
||||
: Variable(shape, VariableKind::Parameter, GetDataType<ElemType>(), new NDArrayView(initValue, shape, device), true, {}, name)
|
||||
{
|
||||
}
|
||||
|
||||
|
@ -738,7 +1046,7 @@ namespace CNTK
|
|||
/// Construct a Constant whose initial contents are a copy of the specified value
|
||||
///
|
||||
Constant(const NDArrayViewPtr& value, const std::wstring& name = L"")
|
||||
: Variable(value->Shape(), VariableKind::Constant, value->DataType(), value->DeepClone(true), false, name)
|
||||
: Variable(value->Shape(), VariableKind::Constant, value->DataType(), value->DeepClone(true), false, {}, name)
|
||||
{
|
||||
}
|
||||
|
||||
|
@ -749,7 +1057,7 @@ namespace CNTK
|
|||
///
|
||||
template<typename ElemType>
|
||||
Constant(const NDShape& shape, ElemType initValue, const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice(), const std::wstring& name = L"")
|
||||
: Variable(shape, VariableKind::Constant, GetDataType<ElemType>(), new NDArrayView(initValue, shape, device), false, name)
|
||||
: Variable(shape, VariableKind::Constant, GetDataType<ElemType>(), new NDArrayView(initValue, shape, device), false, {}, name)
|
||||
{
|
||||
}
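Note that a scalar Constant built this way (empty shape, single value) is the kind of initial-state operand the PastValue/FutureValue operations introduced below currently require; for example (illustrative):

    CNTK::Constant zeroState({}, 0.0f); // scalar constant usable as a recurrence initial state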
|
||||
|
||||
|
@ -774,6 +1082,38 @@ namespace CNTK
|
|||
|
||||
static_assert(sizeof(Constant) == sizeof(Variable), "The Constant type should not have any data fields beyond what its base type 'Variable' has.");
|
||||
|
||||
///
|
||||
/// Denotes a Placeholder input to a Function.
|
||||
/// All placeholder inputs of a Function must be replaced with non-placeholder Variables before Forward evaluation of the Function.
|
||||
///
|
||||
class CNTK_API Placeholder final : public Variable
|
||||
{
|
||||
template <typename T>
|
||||
friend struct std::hash;
|
||||
|
||||
friend class Function;
|
||||
|
||||
public:
|
||||
///
|
||||
/// Construct a Placeholder with the specified NDShape
|
||||
///
|
||||
explicit Placeholder(const NDShape& shape, const std::wstring& name = L"")
|
||||
: Variable(shape, VariableKind::Placeholder, DataType::Unknown, nullptr, false, {Axis::DefaultDynamicAxis}, name)
|
||||
{
|
||||
}
|
||||
|
||||
///
|
||||
/// Downcast a Variable to a Placeholder. Only allowed if the VariableKind is Placeholder; throws an exception otherwise.
|
||||
///
|
||||
explicit Placeholder(const Variable& variable)
|
||||
: Variable(variable)
|
||||
{
|
||||
if (!IsPlaceholder())
|
||||
InvalidArgument("A non-placeholder Variable being converted to a Placeholder");
|
||||
}
|
||||
};
|
||||
|
||||
static_assert(sizeof(Placeholder) == sizeof(Variable), "The Placeholder type should not have any data fields beyond what its base type 'Variable' has.");
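A quick illustration (not part of the commit) of the intended round trip between Variable and Placeholder:

    CNTK::Placeholder p({ 128 });       // shape-only stand-in; its DataType stays Unknown until replacement
    CNTK::Variable asVariable = p;      // implicit upcast to Variable
    CNTK::Placeholder back(asVariable); // downcast succeeds because the underlying kind is Placeholder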
|
||||
#pragma warning(pop)
|
||||
}
|
||||
|
||||
|
@ -801,6 +1141,14 @@ namespace std {
|
|||
return std::hash<CNTK::Variable>()(x);
|
||||
}
|
||||
};
|
||||
|
||||
template <> struct hash<CNTK::Placeholder>
|
||||
{
|
||||
size_t operator()(const CNTK::Placeholder& x) const
|
||||
{
|
||||
return std::hash<CNTK::Variable>()(x);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
namespace CNTK
|
||||
|
@ -822,7 +1170,7 @@ namespace CNTK
|
|||
BackPropState(const FunctionPtr& function) : m_function(function) {}
|
||||
|
||||
private:
|
||||
virtual void _ForceRTTIGeneration()
|
||||
virtual void _ForceRTTIGeneration() final
|
||||
{
|
||||
LogicError("This is an internal method that is never supposed to be called");
|
||||
}
|
||||
|
@ -832,6 +1180,9 @@ namespace CNTK
|
|||
};
|
||||
typedef _Internal::_ReferenceCounterSharedPtr<BackPropState> BackPropStatePtr;
|
||||
|
||||
#pragma warning(push)
|
||||
#pragma warning(disable : 4251 4275)
|
||||
|
||||
///
|
||||
/// Represents a function (optionally differentiable)
|
||||
/// A Function is a symbolic entity with zero or more input arguments and one or more outputs.
|
||||
|
@ -839,8 +1190,10 @@ namespace CNTK
|
|||
/// A Function effectively is an arbitrary computation graph composed of other primitive Functions, where Variable objects
|
||||
/// for the edges and leaves of the graph.
|
||||
///
|
||||
class Function : public _Internal::_ReferenceCounter
|
||||
class CNTK_API Function : public _Internal::_ReferenceCounter
|
||||
{
|
||||
friend class CompositeFunction;
|
||||
|
||||
public:
|
||||
///
|
||||
/// Computes and stores the values of specified variables in the 'outputs' map, using provided 'inputs' values corresponding
|
||||
|
@ -944,7 +1297,7 @@ namespace CNTK
|
|||
///
|
||||
std::vector<Variable> Inputs() const
|
||||
{
|
||||
return m_inputs;
|
||||
return _Inputs();
|
||||
}
|
||||
|
||||
///
|
||||
|
@ -971,7 +1324,9 @@ namespace CNTK
|
|||
///
|
||||
std::unordered_set<Variable> Arguments() const
|
||||
{
|
||||
return m_arguments;
|
||||
return FilteredInputs<Variable>([](const Variable& var) {
|
||||
return ((var.Kind() == VariableKind::Input) || (var.Kind() == VariableKind::Output));
|
||||
});
|
||||
}
|
||||
|
||||
///
|
||||
|
@ -979,7 +1334,9 @@ namespace CNTK
|
|||
///
|
||||
std::unordered_set<Parameter> Parameters() const
|
||||
{
|
||||
return m_parameters;
|
||||
return FilteredInputs<Parameter>([](const Variable& var) {
|
||||
return (var.Kind() == VariableKind::Parameter);
|
||||
});
|
||||
}
|
||||
|
||||
///
|
||||
|
@ -987,11 +1344,58 @@ namespace CNTK
|
|||
///
|
||||
std::unordered_set<Constant> Constants() const
|
||||
{
|
||||
return m_constants;
|
||||
return FilteredInputs<Constant>([](const Variable& var) {
|
||||
return (var.Kind() == VariableKind::Constant);
|
||||
});
|
||||
}
|
||||
|
||||
///
|
||||
/// Returns the set of all Placeholder variables of 'this' Function.
|
||||
///
|
||||
std::unordered_set<Placeholder> Placeholders() const
|
||||
{
|
||||
return FilteredInputs<Placeholder>([](const Variable& var) {
|
||||
return (var.Kind() == VariableKind::Placeholder);
|
||||
});
|
||||
}
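These filtered views make simple traversals over a composite Function's leaves straightforward; e.g. (illustrative, with 'classifier' standing for some previously constructed FunctionPtr):

    size_t totalParameterCount = 0;
    for (const auto& param : classifier->Parameters())
        totalParameterCount += param.Shape().TotalSize();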
|
||||
|
||||
FunctionPtr ReplacePlaceholders(const std::unordered_map<Placeholder, Variable>& placeholderReplacements)
|
||||
{
|
||||
// Cannot be called on primitive functions
|
||||
if (RootFunction() == nullptr)
|
||||
InvalidArgument("ReplacePlaceholders should never be called on primitive functions");
|
||||
|
||||
_Internal::_SimpleSet<const Function*> visitedFunctions;
|
||||
_Internal::_SimpleSet<Placeholder> replacedPlaceholders;
|
||||
auto abiSafePlaceholderReplacementsMap = _Internal::_SimpleMap<Placeholder, Variable>::CreateSimpleMap(placeholderReplacements);
|
||||
_ReplacePlaceholders(abiSafePlaceholderReplacementsMap, visitedFunctions, replacedPlaceholders);
|
||||
|
||||
if (abiSafePlaceholderReplacementsMap.Keys() != replacedPlaceholders)
|
||||
InvalidArgument("At least one of the placeholders specified for replacement was not found in the function");
|
||||
|
||||
return this;
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
template <typename VariableType, typename FilterFunction>
|
||||
std::unordered_set<VariableType> FilteredInputs(FilterFunction&& filterFunc) const
|
||||
{
|
||||
std::unordered_set<VariableType> filteredInputs;
|
||||
auto inputs = Inputs();
|
||||
for (size_t i = 0; i < inputs.size(); ++i)
|
||||
{
|
||||
if (filterFunc(inputs[i]))
|
||||
filteredInputs.insert(VariableType(inputs[i]));
|
||||
}
|
||||
|
||||
return filteredInputs;
|
||||
|
||||
}
|
||||
|
||||
_Internal::_SimpleVector<Variable> _Inputs() const;
|
||||
virtual void _ReplacePlaceholders(const _Internal::_SimpleMap<Placeholder, Variable>& placeholderReplacements, _Internal::_SimpleSet<const Function*>& visitedFunctions, _Internal::_SimpleSet<Placeholder>& replacedPlaceholders);
|
||||
|
||||
// Disallow copy and move construction and assignment
|
||||
Function(const Function&) = delete;
|
||||
Function(Function&&) = delete;
|
||||
|
@ -1010,21 +1414,13 @@ namespace CNTK
|
|||
{
|
||||
m_inputs.PushBack(inputs[i]);
|
||||
|
||||
switch (inputs[i].Kind())
|
||||
if ((inputs[i].Kind() != VariableKind::Input) &&
|
||||
(inputs[i].Kind() != VariableKind::Output) &&
|
||||
(inputs[i].Kind() != VariableKind::Parameter) &&
|
||||
(inputs[i].Kind() != VariableKind::Constant) &&
|
||||
(inputs[i].Kind() != VariableKind::Placeholder))
|
||||
{
|
||||
case VariableKind::Input:
|
||||
case VariableKind::Output:
|
||||
m_arguments.Insert(inputs[i]);
|
||||
break;
|
||||
case VariableKind::Parameter:
|
||||
m_parameters.Insert(*(reinterpret_cast<const Parameter*>(&(inputs[i]))));
|
||||
break;
|
||||
case VariableKind::Constant:
|
||||
m_constants.Insert(*(reinterpret_cast<const Constant*>(&(inputs[i]))));
|
||||
break;
|
||||
default:
|
||||
InvalidArgument("Function input has invalid VariableKind!");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1054,13 +1450,11 @@ namespace CNTK
|
|||
|
||||
_Internal::_SimpleVector<Variable> m_inputs;
|
||||
_Internal::_SimpleVector<Variable> m_outputs;
|
||||
_Internal::_SimpleSet<Variable> m_arguments;
|
||||
_Internal::_SimpleSet<Parameter> m_parameters;
|
||||
_Internal::_SimpleSet<Constant> m_constants;
|
||||
|
||||
FunctionPtr m_rootFunction;
|
||||
wchar_t* m_name;
|
||||
};
|
||||
#pragma warning(pop)
|
||||
|
||||
CNTK_API FunctionPtr _Combine(const _Internal::_SimpleVector<FunctionPtr>& operands, const std::wstring& name = L"");
|
||||
|
||||
|
@ -1080,6 +1474,11 @@ namespace CNTK
|
|||
///
|
||||
CNTK_API FunctionPtr Sigmoid(const Variable& operand, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in elementwise tanh operation with the specified input operand.
|
||||
///
|
||||
CNTK_API FunctionPtr Tanh(const Variable& operand, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in operation to compute cross-entropy with softmax for specified input operands.
|
||||
///
|
||||
|
@ -1090,6 +1489,35 @@ namespace CNTK
|
|||
///
|
||||
CNTK_API FunctionPtr PredictionError(const Variable& prediction, const Variable& labels, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in elementwise exp operation with the specified input operand.
|
||||
///
|
||||
CNTK_API FunctionPtr Exp(const Variable& operand, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in operation for getting the past value along the lone dynamic axis of the specified operand.
|
||||
/// Throws an exception if the operand has more than one dynamic axis.
|
||||
///
|
||||
CNTK_API FunctionPtr PastValue(const Variable& initialState, const Variable& operand, size_t stepSize, const std::wstring& name = L"");
|
||||
|
||||
//CNTK_API FunctionPtr PastValue(const Variable& initialState, const Variable& operand, Axis axis, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in operation for getting the future value along the lone dynamic axis of the specified operand.
|
||||
/// Throws an exception if the operand has more than one dynamic axis.
|
||||
///
|
||||
CNTK_API FunctionPtr FutureValue(const Variable& initialState, const Variable& operand, size_t stepSize, const std::wstring& name = L"");
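Putting the pieces together, the recurrence support added by this commit is expressed by wiring a Placeholder into the loop and then closing it with ReplacePlaceholders. A hedged sketch (it assumes a Function exposes its output Variables via an Outputs() accessor, uses only operations declared in this header, and the particular composition is illustrative):

    using namespace CNTK;
    Variable input({ 4 }, DataType::Float, L"input");   // carries Axis::DefaultDynamicAxis
    Constant initialState({}, 0.0f);                    // scalar initial state, as currently required
    Placeholder prevOutput({ 4 });                      // stand-in for the recurrent connection

    FunctionPtr product = ElementTimes(input, prevOutput);
    FunctionPtr step = Sigmoid(product->Outputs()[0]);
    FunctionPtr delayed = PastValue(initialState, step->Outputs()[0], /*stepSize=*/ 1);
    step->ReplacePlaceholders({ { prevOutput, delayed->Outputs()[0] } });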
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in elementwise multiplication operation on specified tensor input operands.
|
||||
///
|
||||
CNTK_API FunctionPtr ElementTimes(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in sum reduction operation on specified tensor input operand along all the axes
|
||||
///
|
||||
CNTK_API FunctionPtr ReduceSum(const Variable& operand, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create a new Function instance which just combines the outputs of the specified list of 'operands' Functions such that the 'Outputs' of the
|
||||
/// new 'Function' are union of the 'Outputs' of each of the specified 'operands' Functions.
|
||||
|
|
|
@ -51,6 +51,7 @@ namespace CNTK
|
|||
{
|
||||
// Forward declarations
|
||||
class CompositeFunction;
|
||||
class Function;
|
||||
|
||||
namespace _Internal
|
||||
{
|
||||
|
@ -233,6 +234,8 @@ namespace CNTK
|
|||
template <typename ValueType>
|
||||
friend CNTK_API bool operator==(const _SimpleVector<ValueType>& first, const _SimpleVector<ValueType>& second);
|
||||
|
||||
friend class Function;
|
||||
|
||||
public:
|
||||
_SimpleVector();
|
||||
_SimpleVector(size_t numElements, const T& initVal = T());
|
||||
|
@ -361,6 +364,8 @@ namespace CNTK
|
|||
class CNTK_API _SimpleMap final
|
||||
{
|
||||
friend class CompositeFunction;
|
||||
friend class Function;
|
||||
|
||||
public:
|
||||
_SimpleMap();
|
||||
~_SimpleMap();
|
||||
|
@ -398,6 +403,9 @@ namespace CNTK
|
|||
class NDArrayView;
|
||||
typedef _Internal::_ReferenceCounterSharedPtr<NDArrayView> NDArrayViewPtr;
|
||||
|
||||
class NDMask;
|
||||
typedef _Internal::_ReferenceCounterSharedPtr<NDMask> NDMaskPtr;
|
||||
|
||||
class Value;
|
||||
typedef _Internal::_ReferenceCounterSharedPtr<Value> ValuePtr;
|
||||
|
||||
|
|
|
@ -141,6 +141,7 @@
|
|||
</ClCompile>
|
||||
<ClCompile Include="Function.cpp" />
|
||||
<ClCompile Include="NDArrayView.cpp" />
|
||||
<ClCompile Include="NDMask.cpp" />
|
||||
<ClCompile Include="stdafx.cpp">
|
||||
<PrecompiledHeader>Create</PrecompiledHeader>
|
||||
</ClCompile>
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
<ClCompile Include="Function.cpp" />
|
||||
<ClCompile Include="Variable.cpp" />
|
||||
<ClCompile Include="Utils.cpp" />
|
||||
<ClCompile Include="NDMask.cpp" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="stdafx.h" />
|
||||
|
|
|
@ -13,4 +13,8 @@ namespace CNTK
|
|||
// TODO: Should return the global default device.
|
||||
return GPUDevice(0);
|
||||
}
|
||||
|
||||
/*static*/ Axis Axis::DefaultDynamicAxis = Axis(L"defaultDynamicAxis");
|
||||
/*static*/ Axis Axis::BatchAxis = Axis(L"batchAxis");
|
||||
/*static*/ Axis Axis::AllAxes = Axis(L"allAxes");
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
#include "ComputationNetworkBuilder.h"
|
||||
#include "Utils.h"
|
||||
#include "ComputationNode.h"
|
||||
#include "ReshapingNodes.h"
|
||||
|
||||
using namespace Microsoft::MSR::CNTK;
|
||||
|
||||
|
@ -16,12 +17,44 @@ bool g_shareNodeValueMatrices = true;
|
|||
|
||||
namespace CNTK
|
||||
{
|
||||
_Internal::_SimpleVector<Variable> Function::_Inputs() const
|
||||
{
|
||||
const CompositeFunction* compositeFunction = dynamic_cast<const CompositeFunction*>(this);
|
||||
if (compositeFunction == nullptr)
|
||||
return m_inputs;
|
||||
else
|
||||
return _Internal::_SimpleVector<Variable>::CreateSimpleVector(compositeFunction->DetermineInputs());
|
||||
}
|
||||
|
||||
/*virtual*/ void Function::_ReplacePlaceholders(const _Internal::_SimpleMap<Placeholder, Variable>& placeholderReplacements, _Internal::_SimpleSet<const Function*>& visitedFunctions, _Internal::_SimpleSet<Placeholder>& replacedPlaceholders)
|
||||
{
|
||||
visitedFunctions.Insert(this);
|
||||
|
||||
for (auto iter = m_inputs.m_vector->begin(); iter != m_inputs.m_vector->end(); ++iter)
|
||||
{
|
||||
if (iter->IsPlaceholder())
|
||||
{
|
||||
Placeholder placeholder(*iter);
|
||||
if (placeholderReplacements.Contains(placeholder))
|
||||
{
|
||||
*iter = placeholderReplacements[placeholder];
|
||||
replacedPlaceholders.Insert(placeholder);
|
||||
}
|
||||
}
|
||||
else if ((iter->Kind() == VariableKind::Output) && !visitedFunctions.Contains(iter->Owner()))
|
||||
iter->Owner()->_ReplacePlaceholders(placeholderReplacements, visitedFunctions, replacedPlaceholders);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ElementType>
|
||||
/*static*/ ComputationNodeBasePtr CompositeFunction::GetNode(const Variable& variable, ComputationNetworkBuilder<ElementType>& builder, std::unordered_map<Variable, ComputationNodeBasePtr>& variableToNodeMap, std::unordered_map<Variable, bool>& isVariableRootMap)
|
||||
/*static*/ ComputationNodeBasePtr CompositeFunction::GetNode(const Variable& variable, Microsoft::MSR::CNTK::ComputationNetworkPtr& network, ComputationNetworkBuilder<ElementType>& builder, std::unordered_map<Variable, ComputationNodeBasePtr>& variableToNodeMap, std::unordered_map<Variable, bool>& isVariableRootMap)
|
||||
{
|
||||
if (variableToNodeMap.find(variable) != variableToNodeMap.end())
|
||||
return variableToNodeMap[variable];
|
||||
|
||||
// Add a null entry in the map for this variable, to break infinite recursion when processing recurrent graphs
|
||||
variableToNodeMap[variable] = nullptr;
|
||||
|
||||
std::shared_ptr<ComputationNode<ElementType>> computationNodePtr;
|
||||
if (variable.IsParameter() || variable.IsConstant())
|
||||
{
|
||||
|
@ -47,7 +80,7 @@ namespace CNTK
|
|||
else
|
||||
{
|
||||
assert(variable.Kind() == VariableKind::Output);
|
||||
computationNodePtr = GetOutputVariableNode(variable, builder, variableToNodeMap, isVariableRootMap)->As<ComputationNode<ElementType>>()->shared_from_this();
|
||||
computationNodePtr = GetOutputVariableNode(variable, network, builder, variableToNodeMap, isVariableRootMap)->As<ComputationNode<ElementType>>()->shared_from_this();
|
||||
}
|
||||
|
||||
variableToNodeMap[variable] = computationNodePtr;
|
||||
|
@ -56,9 +89,10 @@ namespace CNTK
|
|||
}
|
||||
|
||||
template <typename ElementType>
|
||||
/*static*/ ComputationNodeBasePtr CompositeFunction::GetOutputVariableNode(const Variable& variable, ComputationNetworkBuilder<ElementType>& builder, std::unordered_map<Variable, ComputationNodeBasePtr>& variableToNodeMap, std::unordered_map<Variable, bool>& isVariableRootMap)
|
||||
/*static*/ ComputationNodeBasePtr CompositeFunction::GetOutputVariableNode(const Variable& variable, Microsoft::MSR::CNTK::ComputationNetworkPtr& network, ComputationNetworkBuilder<ElementType>& builder, std::unordered_map<Variable, ComputationNodeBasePtr>& variableToNodeMap, std::unordered_map<Variable, bool>& isVariableRootMap)
|
||||
{
|
||||
assert(variable.Kind() == VariableKind::Output);
|
||||
|
||||
Function* function = variable.Owner();
|
||||
ComputationNodeBasePtr computationNodePtr;
|
||||
if (dynamic_cast<PrimitiveFunction*>(function) != nullptr)
|
||||
|
@ -67,11 +101,15 @@ namespace CNTK
|
|||
|
||||
// Create the nodes corresponding to the inputs
|
||||
auto functionInputs = primitiveFunction->Inputs();
|
||||
std::shared_ptr<ComputationNode<ElementType>> input0Node = GetNode(functionInputs[0], builder, variableToNodeMap, isVariableRootMap)->As<ComputationNode<ElementType>>()->shared_from_this();
|
||||
auto input0BaseNodePtr = GetNode(functionInputs[0], network, builder, variableToNodeMap, isVariableRootMap);
|
||||
std::shared_ptr<ComputationNode<ElementType>> input0Node = (input0BaseNodePtr != nullptr) ? input0BaseNodePtr->As<ComputationNode<ElementType>>()->shared_from_this() : nullptr;
|
||||
|
||||
std::shared_ptr<ComputationNode<ElementType>> input1Node;
|
||||
if (functionInputs.size() > 1)
|
||||
input1Node = GetNode(functionInputs[1], builder, variableToNodeMap, isVariableRootMap)->As<ComputationNode<ElementType>>()->shared_from_this();
|
||||
{
|
||||
auto input1BaseNodePtr = GetNode(functionInputs[1], network, builder, variableToNodeMap, isVariableRootMap);
|
||||
input1Node = (input1BaseNodePtr != nullptr) ? input1BaseNodePtr->As<ComputationNode<ElementType>>()->shared_from_this() : nullptr;
|
||||
}
|
||||
|
||||
PrimitiveOpType op = primitiveFunction->OpType();
|
||||
switch (op)
|
||||
|
@ -86,15 +124,60 @@ namespace CNTK
|
|||
case PrimitiveOpType::Sigmoid:
|
||||
computationNodePtr = builder.Sigmoid(input0Node, function->Name());
|
||||
break;
|
||||
case PrimitiveOpType::Tanh:
|
||||
computationNodePtr = builder.Tanh(input0Node, function->Name());
|
||||
break;
|
||||
case PrimitiveOpType::CrossEntropyWithSoftmax:
|
||||
computationNodePtr = builder.CrossEntropyWithSoftmax(input1Node, input0Node, function->Name());
|
||||
break;
|
||||
case PrimitiveOpType::PredictionError:
|
||||
computationNodePtr = builder.ErrorPrediction(input1Node, input0Node, function->Name());
|
||||
break;
|
||||
case PrimitiveOpType::Exp:
|
||||
computationNodePtr = builder.Exp(input0Node, function->Name());
|
||||
break;
|
||||
case PrimitiveOpType::PastValue:
|
||||
case PrimitiveOpType::FutureValue:
|
||||
{
|
||||
Variable initialStateVar = functionInputs[0];
|
||||
Variable inputOperandVar = functionInputs[1];
|
||||
// TODO: Currently we only support a scalar initial state
|
||||
if (!initialStateVar.IsConstant() || (initialStateVar.Shape().NumAxes() > 0))
|
||||
LogicError("Currently PastValue/FutureValue Function only supports scalar initial state");
|
||||
|
||||
// TODO: We currently only support input operand with 1 static axis for PastValue/FutureValue
|
||||
if (inputOperandVar.Shape().NumAxes() != 1)
|
||||
LogicError("Currently PastValue/FutureValue Function only supports input operand with 1 static axis");
|
||||
|
||||
// TODO: We currently only support input operand with 1 dynamic axis for PastValue/FutureValue
|
||||
if (inputOperandVar.DynamicAxes().size() != 1)
|
||||
LogicError("Currently PastValue/FutureValue Function only supports input operand with 1 dynamic axis");
|
||||
|
||||
// Get the initial state of the PastValue/FutureValue operation
|
||||
ElementType initStateValue;
|
||||
NDArrayView tempView({}, &initStateValue, 1, DeviceDescriptor::CPUDevice());
|
||||
tempView.CopyFrom(*Constant(initialStateVar).Value());
|
||||
|
||||
if (op == PrimitiveOpType::PastValue)
|
||||
computationNodePtr = builder.PastValue(input1Node, (float)initStateValue, inputOperandVar.Shape()[0], primitiveFunction->FunctionConfig()[L"stepSize"].GetValue<size_t>(), function->Name());
|
||||
else
|
||||
computationNodePtr = builder.FutureValue(input1Node, (float)initStateValue, inputOperandVar.Shape()[0], primitiveFunction->FunctionConfig()[L"stepSize"].GetValue<size_t>(), function->Name());
|
||||
|
||||
break;
|
||||
}
|
||||
case PrimitiveOpType::ElementTimes:
|
||||
computationNodePtr = builder.ElementTimes(input0Node, input1Node, function->Name());
|
||||
break;
|
||||
case PrimitiveOpType::ReduceSum:
|
||||
{
|
||||
// TODO: Use the new ReduceElements node instead of the legacy SumElements node for reduction. Currently ReduceElements has incorrect MBLayout inference.
|
||||
//computationNodePtr = network->AddNodeToNetAndAttachInputs(New<ReduceElementsNode<ElementType>>(network->GetDeviceId(), function->Name(), L"Sum", 0), { input0Node });
|
||||
computationNodePtr = builder.Sum(input0Node, function->Name());
|
||||
break;
|
||||
}
|
||||
case PrimitiveOpType::Combine:
|
||||
for (size_t i = 0; i < functionInputs.size(); ++i)
|
||||
GetNode(functionInputs[i], builder, variableToNodeMap, isVariableRootMap);
|
||||
GetNode(functionInputs[i], network, builder, variableToNodeMap, isVariableRootMap);
|
||||
|
||||
computationNodePtr = variableToNodeMap[variable];
|
||||
|
||||
|
@ -152,7 +235,7 @@ namespace CNTK
|
|||
std::vector<ComputationNodeBasePtr> forwardRootNodes;
|
||||
for (size_t i = 0; i < rootFunctionOutputs.size(); ++i)
|
||||
{
|
||||
auto currentRootNode = GetNode(rootFunctionOutputs[i], builder, m_variableToNodeMap, m_isVariableRootMap);
|
||||
auto currentRootNode = GetNode(rootFunctionOutputs[i], m_computationNetwork, builder, m_variableToNodeMap, m_isVariableRootMap);
|
||||
forwardRootNodes.push_back(currentRootNode);
|
||||
|
||||
if (backpropRoots.Contains(rootFunctionOutputs[i]))
|
||||
|
@ -168,6 +251,26 @@ namespace CNTK
|
|||
|
||||
m_currentBackpropRoots = backpropRoots;
|
||||
|
||||
// In case of recurrence, the inputs of some of the ComputationNodes are not attached due to cycles.
|
||||
// Now attach those after we have created all ComputationNodes in the network
|
||||
for (auto iter = m_variableToNodeMap.begin(); iter != m_variableToNodeMap.end(); ++iter)
|
||||
{
|
||||
auto currentComputationNodeInputs = iter->second->GetInputs();
|
||||
|
||||
// TODO: Can any node other than a non PastValue/FutureValue Function have a null input attached after the first pass is finished?
|
||||
if (std::find(currentComputationNodeInputs.begin(), currentComputationNodeInputs.end(), nullptr) != currentComputationNodeInputs.end())
|
||||
{
|
||||
// We found a null input; this variable must correspond to a PastValue or FutureValue function
|
||||
const PrimitiveFunction* primitiveFunc = dynamic_cast<const PrimitiveFunction*>(iter->first.Owner().GetPtr());
|
||||
if ((primitiveFunc == nullptr) || ((primitiveFunc->OpType() != PrimitiveOpType::PastValue) && (primitiveFunc->OpType() != PrimitiveOpType::FutureValue)))
|
||||
InvalidArgument("Invalid Function graph detected; recurrence found at a Function that is not a PastValue/FutureValue function");
|
||||
|
||||
// The 2nd input of the PastValue/FutureValue function denotes the recurrent input
|
||||
auto actualInput = m_variableToNodeMap[primitiveFunc->Inputs()[1]];
|
||||
iter->second->AttachInputs({ actualInput });
|
||||
}
|
||||
}
|
||||
|
||||
m_computationNetwork->CompileNetwork();
|
||||
|
||||
// Verify that the shapes of the output Variables that we computed match the corresponding nodes in the ComputationNetwork
|
||||
|
@ -193,42 +296,195 @@ namespace CNTK
|
|||
return m_computationNetwork;
|
||||
}
|
||||
|
||||
/*static*/ void CompositeFunction::CopyNDArrayViewToComputationNodeValue(const NDArrayViewPtr& arrayView, ComputationNodeBasePtr node)
|
||||
template <typename ElementType>
|
||||
/*static*/ std::pair<std::shared_ptr<const Matrix<ElementType>>, MBLayoutPtr> CompositeFunction::GetCNTKImplMatrixAndMBLayoutFromValueObject(Variable var, const ValuePtr& value)
|
||||
{
|
||||
switch (arrayView->DataType())
|
||||
if (var.DataType() != value->Data()->DataType())
|
||||
LogicError("The Variable's DataType %s does not match the corresponding Value's DataType %s", DataTypeName(var.DataType()), DataTypeName(value->Data()->DataType()));
|
||||
|
||||
if (GetDataType<ElementType>() != value->Data()->DataType())
|
||||
LogicError("The specified ElementType %s does not match the DataType %s", typeid(ElementType).name(), DataTypeName(value->Data()->DataType()));
|
||||
|
||||
if (value->Data()->Shape().NumAxes() == var.Shape().NumAxes())
|
||||
return{ value->Data()->GetMatrix<ElementType>(), nullptr };
|
||||
|
||||
if (value->Data()->Shape().NumAxes() != (var.Shape().NumAxes() + var.DynamicAxes().size() + 1))
|
||||
InvalidArgument("Value's number of axes should be larger than the Variable's number of axes by 1 + number of dynamic axes");
|
||||
|
||||
if (var.DynamicAxes().size() > 1)
|
||||
LogicError("More than one dynamic axis for a variable is currently unsupported");
|
||||
|
||||
size_t maxNumTimeSteps = value->Data()->Shape()[var.Shape().NumAxes()];
|
||||
size_t numSequences = value->Data()->Shape()[var.Shape().NumAxes() + 1];
|
||||
|
||||
auto mask = value->Mask();
|
||||
if ((mask != nullptr) && ((var.Shape().NumAxes() + mask->Shape().NumAxes()) != value->Data()->Shape().NumAxes()))
|
||||
InvalidArgument("Invalid Value object; the sum of the #axes of the mask and data does not equal the Variable's number of axes by 1 + number of dynamic axes");
|
||||
|
||||
if ((numSequences == 1) || (maxNumTimeSteps == 1))
|
||||
{
|
||||
case DataType::Float:
|
||||
{
|
||||
auto& nodeData = node->As<ComputationNode<float>>()->Value();
|
||||
nodeData.AssignValuesOf(*(arrayView->GetMatrix<float>()));
|
||||
break;
|
||||
// The data need not be shuffled
|
||||
std::shared_ptr<const Matrix<ElementType>> matrixData = value->Data()->GetMatrix<ElementType>(var.Shape().NumAxes());
|
||||
auto layout = std::make_shared<MBLayout>();
|
||||
if (maxNumTimeSteps == 1)
|
||||
layout->InitAsFrameMode(numSequences);
|
||||
else
|
||||
{
|
||||
layout->Init(1, maxNumTimeSteps);
|
||||
layout->AddSequence(0, 0, 0, maxNumTimeSteps);
|
||||
}
|
||||
|
||||
return{ matrixData , layout};
|
||||
}
|
||||
case DataType::Double:
|
||||
else
|
||||
{
|
||||
auto& nodeData = node->As<ComputationNode<double>>()->Value();
|
||||
nodeData.AssignValuesOf(*(arrayView->GetMatrix<double>()));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
LogicError("Unsupported DataType %s", DataTypeName(arrayView->DataType()));
|
||||
break;
|
||||
std::vector<size_t> sequenceLengths(numSequences, maxNumTimeSteps);
|
||||
if (mask != nullptr)
|
||||
{
|
||||
// Determine the sequence lengths from the mask
|
||||
std::unique_ptr<char[]> maskData(mask->GetMatrix()->CopyToArray());
|
||||
for (size_t i = 0; i < numSequences; ++i)
|
||||
{
|
||||
size_t currentSequenceLength = 0;
|
||||
bool currentSequenceEndAlreadyFound = false;
|
||||
for (size_t j = 0; j < maxNumTimeSteps; ++j)
|
||||
{
|
||||
if (maskData[(i * maxNumTimeSteps) + j] == 1)
|
||||
{
|
||||
if (currentSequenceEndAlreadyFound)
|
||||
InvalidArgument("Invalid Value object; only trailing steps of a sequence can be masked");
|
||||
|
||||
currentSequenceLength++;
|
||||
}
|
||||
else
|
||||
{
|
||||
currentSequenceEndAlreadyFound = true;
|
||||
}
|
||||
}
|
||||
|
||||
sequenceLengths[i] = currentSequenceLength;
|
||||
}
|
||||
}
|
||||
|
||||
// The data needs to be rearranged since CNTK requires sequences to be interleaved across timesteps
|
||||
std::vector<MBLayout::SequenceInfo> sequences;
|
||||
for (size_t i = 0; i < numSequences; ++i)
|
||||
sequences.push_back({ i, SIZE_MAX, 0, sequenceLengths[i]});
|
||||
|
||||
auto layout = std::make_shared<MBLayout>();
|
||||
std::vector<std::pair<size_t, size_t>> placement;
|
||||
std::vector<size_t> rowAllocations;
|
||||
layout->InitAsPackedSequences(sequences, placement, rowAllocations);
|
||||
if (maxNumTimeSteps != layout->GetNumTimeSteps())
|
||||
LogicError("The number of time steps in the packed MBLayout does not match the longest sequence's length in the Value object");
|
||||
|
||||
if (numSequences != layout->GetNumSequences())
|
||||
LogicError("The number of sequences in the packed MBLayout does not match the sequence count in the Value object");
|
||||
|
||||
// Now generate the gather indices
|
||||
auto matrixData = std::make_shared<Matrix<ElementType>>(var.Shape().TotalSize(), layout->GetNumCols(), AsCNTKImplDeviceId(value->Data()->Device()));
|
||||
std::vector<ElementType> gatherIndicesVector(layout->GetNumCols(), 0);
|
||||
for (size_t i = 0; i < numSequences; ++i)
|
||||
{
|
||||
size_t targetParallelStreamIdx = placement[i].first;
|
||||
size_t targetStartIdxInParallelStream = placement[i].second;
|
||||
for (size_t j = 0; j < sequenceLengths[i]; ++j)
|
||||
gatherIndicesVector[((targetStartIdxInParallelStream + j) * layout->GetNumParallelSequences()) + targetParallelStreamIdx] = (ElementType)((i * maxNumTimeSteps) + j);
|
||||
}
|
||||
|
||||
auto gatherIdxMatrix = std::make_shared<Matrix<ElementType>>(1, layout->GetNumCols(), gatherIndicesVector.data(), AsCNTKImplDeviceId(value->Data()->Device()));
|
||||
matrixData->DoGatherColumnsOf(0, *gatherIdxMatrix, *(value->Data()->GetMatrix<ElementType>(var.Shape().NumAxes())), 1);
|
||||
return{ matrixData, layout };
|
||||
}
|
||||
}
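The placement computed above pins sequence i to parallel stream placement[i].first starting at time step placement[i].second, so the packed-matrix column for (sequence i, step j) is ((startIdx + j) * numParallelSequences) + streamIdx, and the gather index stored there is the unpacked column (i * maxNumTimeSteps) + j. A standalone toy restatement of that index computation (illustrative only, assuming <vector> and <utility> are available):

    std::vector<size_t> MakeGatherIndices(size_t maxNumTimeSteps, size_t numParallelSequences,
                                          const std::vector<size_t>& sequenceLengths,
                                          const std::vector<std::pair<size_t, size_t>>& placement)
    {
        std::vector<size_t> gatherIndices(numParallelSequences * maxNumTimeSteps, 0);
        for (size_t i = 0; i < sequenceLengths.size(); ++i)
            for (size_t j = 0; j < sequenceLengths[i]; ++j)
                gatherIndices[((placement[i].second + j) * numParallelSequences) + placement[i].first] = (i * maxNumTimeSteps) + j;
        return gatherIndices;
    }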
|
||||
|
||||
/*static*/ void CompositeFunction::CopyNDArrayViewToComputationNodeGradient(const NDArrayViewPtr& arrayView, ComputationNodeBasePtr node)
|
||||
template <typename ElementType>
|
||||
/*static*/ ValuePtr CompositeFunction::GetValueObjectFromCNTKImplMatrixAndMBLayout(Variable var, const Matrix<ElementType>& matrix, const MBLayoutPtr& layout)
|
||||
{
|
||||
switch (arrayView->DataType())
|
||||
if (var.DynamicAxes().size() > 1)
|
||||
LogicError("More than one dynamic axis for a variable is currently unsupported");
|
||||
|
||||
if (GetDataType<ElementType>() != var.DataType())
|
||||
LogicError("The specified ElementType %s does not match the DataType %s", typeid(ElementType).name(), DataTypeName(var.DataType()));
|
||||
|
||||
if ((layout != nullptr) && (matrix.GetNumRows() != var.Shape().TotalSize()))
|
||||
LogicError("Unexpected matrix layout: The number of rows in the matrix does not match the sample size of the Variable");
|
||||
|
||||
NDShape valueDataShape = var.Shape();
|
||||
if (layout != nullptr)
|
||||
valueDataShape = valueDataShape.AppendShape({ layout->GetNumTimeSteps(), layout->GetNumSequences() });
|
||||
|
||||
// No data shuffling needed if no layout or the layout has just one time-step or just one sequence
|
||||
if ((layout == nullptr) || (layout->GetNumTimeSteps() == 1) || (layout->GetNumSequences() == 1))
|
||||
{
|
||||
case DataType::Float:
|
||||
node->As<ComputationNode<float>>()->ResetGradient(*(arrayView->GetMatrix<float>()));
|
||||
break;
|
||||
case DataType::Double:
|
||||
node->As<ComputationNode<double>>()->ResetGradient(*(arrayView->GetMatrix<double>()));
|
||||
break;
|
||||
default:
|
||||
LogicError("Unsupported DataType %s", DataTypeName(arrayView->DataType()));
|
||||
break;
|
||||
// Just create a view over the existing matrix itself
|
||||
|
||||
if ((matrix.GetFormat() != matrixFormatDense) || (matrix.GetFormat() != matrixFormatColMajor))
|
||||
LogicError("Only dense and column-major data storage format is currently supported");
|
||||
|
||||
auto tensorView = new TensorView<ElementType>(std::make_shared<Matrix<ElementType>>(matrix.AsReference()), AsTensorShape(valueDataShape));
|
||||
auto data = NDArrayViewPtr(new NDArrayView(GetDataType<ElementType>(), AsDeviceDescriptor(matrix.GetDeviceId()), StorageFormat::Dense, valueDataShape, true, tensorView), [](_ReferenceCounter* ptr) { delete ptr; });
|
||||
return ValuePtr(new Value(data), [](_ReferenceCounter* ptr) { delete ptr; });
|
||||
}
|
||||
|
||||
if (layout->GetNumCols() != matrix.GetNumCols())
|
||||
LogicError("Bad MBLayout: The number of columns in the MBLayout does not match the number of columns in the data matrix!");
|
||||
|
||||
size_t maxNumTimeSteps = layout->GetNumTimeSteps();
|
||||
size_t numSequences = layout->GetNumSequences();
|
||||
|
||||
std::vector<size_t> sequenceLengths;
|
||||
auto& layoutSequences = layout->GetAllSequences();
|
||||
for (auto iter = layoutSequences.begin(); iter != layoutSequences.end(); ++iter)
|
||||
{
|
||||
if (iter->seqId != GAP_SEQUENCE_ID)
|
||||
sequenceLengths.push_back(iter->GetNumTimeSteps());
|
||||
}
|
||||
|
||||
// Reshuffle the data to unpack and uninterleave it from the CNTK internal (packed) format
|
||||
// Now generate the scatter indices
|
||||
auto shuffledMatrixData = std::make_shared<Matrix<ElementType>>(matrix.GetNumRows(), maxNumTimeSteps * numSequences, matrix.GetDeviceId());
|
||||
|
||||
std::vector<size_t> sequencesShorterThanLongestSequence;
|
||||
for (size_t i = 0; i < numSequences; ++i)
|
||||
if (sequenceLengths[i] != maxNumTimeSteps)
|
||||
sequencesShorterThanLongestSequence.push_back(i);
|
||||
|
||||
// Set the target location of all gaps to be the last step of the first sequence that is shorter than the longest sequence in the batch
|
||||
size_t targetColIdxForInvalidColumns = sequencesShorterThanLongestSequence.empty() ? 0 : (((sequencesShorterThanLongestSequence[0] + 1) * maxNumTimeSteps) - 1);
|
||||
std::vector<ElementType> scatterIndicesVector(layout->GetNumCols(), (ElementType)targetColIdxForInvalidColumns);
|
||||
size_t i = 0;
|
||||
for (auto iter = layoutSequences.begin(); iter != layoutSequences.end(); ++iter)
|
||||
{
|
||||
if (iter->seqId != GAP_SEQUENCE_ID)
|
||||
{
|
||||
size_t targetParallelStreamIdx = iter->s;
|
||||
size_t targetStartIdxInParallelStream = iter->tBegin;
|
||||
for (size_t j = 0; j < iter->GetNumTimeSteps(); ++j)
|
||||
scatterIndicesVector[((targetStartIdxInParallelStream + j) * layout->GetNumParallelSequences()) + targetParallelStreamIdx] = (ElementType)((i * maxNumTimeSteps) + j);
|
||||
|
||||
i++;
|
||||
}
|
||||
}
|
||||
|
||||
auto scatterIdxMatrix = std::make_shared<Matrix<ElementType>>(1, layout->GetNumCols(), scatterIndicesVector.data(), matrix.GetDeviceId());
|
||||
shuffledMatrixData->DoScatterColumnsOf(0, *scatterIdxMatrix, matrix, 1);
|
||||
|
||||
// Create the mask if needed
|
||||
NDMaskPtr mask;
|
||||
if (!sequencesShorterThanLongestSequence.empty())
|
||||
{
|
||||
mask = NDMaskPtr(new NDMask({ maxNumTimeSteps, numSequences }, AsDeviceDescriptor(matrix.GetDeviceId())), [](_ReferenceCounter* ptr) { delete ptr; });
|
||||
for (size_t i = 0; i < sequencesShorterThanLongestSequence.size(); ++i)
|
||||
{
|
||||
size_t shorterSequenceIdx = sequencesShorterThanLongestSequence[i];
|
||||
mask->MaskSection({ sequenceLengths[shorterSequenceIdx], shorterSequenceIdx }, { NDShape::InferredDimension, 1 });
|
||||
}
|
||||
}
|
||||
|
||||
auto tensorView = new TensorView<ElementType>(shuffledMatrixData, AsTensorShape(valueDataShape));
|
||||
auto data = NDArrayViewPtr(new NDArrayView(GetDataType<ElementType>(), AsDeviceDescriptor(matrix.GetDeviceId()), StorageFormat::Dense, valueDataShape, true, tensorView), [](_ReferenceCounter* ptr) { delete ptr; });
|
||||
return ValuePtr(new Value(data, mask), [](_ReferenceCounter* ptr) { delete ptr; });
|
||||
}
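To make the column remapping above concrete, here is a minimal standalone sketch (illustrative only; plain C++ without the CNTK types, and the toy dimensions are assumptions) of the scatter-index construction for a batch of two sequences of lengths 3 and 2: the interleaved MBLayout column (t * numParallelSequences + s) is routed to the unpacked column (seqIdx * maxNumTimeSteps + t), and the single gap column is routed to the last step of the shorter sequence.

#include <cstddef>
#include <cstdio>
#include <vector>

int main()
{
    const size_t numParallelSequences = 2;                  // parallel streams in the MBLayout
    const size_t maxNumTimeSteps = 3;                       // longest sequence in the batch
    const std::vector<size_t> sequenceLengths = { 3, 2 };   // sequence 1 has one gap step

    // Gaps are routed to the last step of the first sequence that is shorter than the longest one
    const size_t gapTargetCol = ((1 + 1) * maxNumTimeSteps) - 1; // == 5
    std::vector<size_t> scatterIndices(numParallelSequences * maxNumTimeSteps, gapTargetCol);

    for (size_t seq = 0; seq < numParallelSequences; ++seq)
        for (size_t t = 0; t < sequenceLengths[seq]; ++t)
            scatterIndices[(t * numParallelSequences) + seq] = (seq * maxNumTimeSteps) + t;

    for (size_t col = 0; col < scatterIndices.size(); ++col)
        std::printf("interleaved col %zu -> unpacked col %zu\n", col, scatterIndices[col]);

    return 0;
}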
|
||||
|
||||
void CompositeFunction::PopulateNetworkInputs(const _Internal::_SimpleMap<Variable, const ValuePtr>& arguments)
|
||||
|
@ -245,17 +501,36 @@ namespace CNTK
|
|||
inputNodes.push_back(argumentComputationNode);
|
||||
|
||||
ValuePtr argumentValue = arguments[*iter];
|
||||
CopyNDArrayViewToComputationNodeValue(argumentValue->Data(), argumentComputationNode);
|
||||
|
||||
// TODO: No sequence support for now
|
||||
// The number of axes of the argument Value must be exactly 1 larger than the number of axes of the variable's shape
|
||||
if (argumentValue->Data()->Shape().NumAxes() != (iter->Shape().NumAxes() + 1))
|
||||
InvalidArgument("Argument value's number of axes should be 1 larger than the argument variable's number of axes");
|
||||
MBLayoutPtr layout;
|
||||
switch (argumentValue->Data()->DataType())
|
||||
{
|
||||
case DataType::Float:
|
||||
{
|
||||
auto CNTKMatrixAndMBLayout = GetCNTKImplMatrixAndMBLayoutFromValueObject<float>(*iter, argumentValue);
|
||||
layout = CNTKMatrixAndMBLayout.second;
|
||||
|
||||
size_t numSamples = argumentValue->Data()->Shape()[iter->Shape().NumAxes()];
|
||||
auto& nodeData = argumentComputationNode->As<ComputationNode<float>>()->Value();
|
||||
nodeData.AssignValuesOf(*CNTKMatrixAndMBLayout.first);
|
||||
break;
|
||||
}
|
||||
case DataType::Double:
|
||||
{
|
||||
auto CNTKMatrixAndMBLayout = GetCNTKImplMatrixAndMBLayoutFromValueObject<double>(*iter, argumentValue);
|
||||
layout = CNTKMatrixAndMBLayout.second;
|
||||
|
||||
argumentComputationNode->GetMBLayout()->InitAsFrameMode(numSamples);
|
||||
auto& nodeData = argumentComputationNode->As<ComputationNode<double>>()->Value();
|
||||
nodeData.AssignValuesOf(*CNTKMatrixAndMBLayout.first);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
LogicError("Unsupported DataType %s", DataTypeName(argumentValue->Data()->DataType()));
|
||||
break;
|
||||
}
|
||||
|
||||
argumentComputationNode->GetMBLayout()->CopyFrom(layout);
|
||||
}
|
||||
|
||||
m_computationNetwork->BumpEvalTimeStamp(inputNodes);
|
||||
}
|
||||
|
||||
|
@ -270,33 +545,35 @@ namespace CNTK
|
|||
InvalidArgument("Gradients cannot be specified for a Variable that is not an Output of the Function");
|
||||
|
||||
auto outputComputationNode = m_variableToNodeMap[iter->first];
|
||||
auto nodeLayout = outputComputationNode->GetMBLayout();
|
||||
|
||||
ValuePtr gradientValue = iter->second;
|
||||
CopyNDArrayViewToComputationNodeGradient(gradientValue->Data(), outputComputationNode);
|
||||
}
|
||||
}
|
||||
|
||||
/*static*/ void CompositeFunction::CopyComputationNodeDataToNDArrayView(const Microsoft::MSR::CNTK::ComputationNodeBasePtr& node, NDArrayViewPtr arrayView, bool copyGradient)
|
||||
{
|
||||
switch (arrayView->DataType())
|
||||
{
|
||||
case DataType::Float:
|
||||
{
|
||||
auto& outputMatrix = copyGradient ? node->As<ComputationNode<float>>()->Gradient() : node->As<ComputationNode<float>>()->Value();
|
||||
auto arrayViewMatrix = arrayView->GetWritableMatrix<float>();
|
||||
arrayViewMatrix->AssignValuesOf(outputMatrix);
|
||||
break;
|
||||
}
|
||||
case DataType::Double:
|
||||
{
|
||||
auto& outputMatrix = copyGradient ? node->As<ComputationNode<double>>()->Gradient() : node->As<ComputationNode<double>>()->Value();
|
||||
auto arrayViewMatrix = arrayView->GetWritableMatrix<double>();
|
||||
arrayViewMatrix->AssignValuesOf(outputMatrix);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
LogicError("Unsupported DataType %s", DataTypeName(arrayView->DataType()));
|
||||
break;
|
||||
MBLayoutPtr layout;
|
||||
switch (gradientValue->Data()->DataType())
|
||||
{
|
||||
case DataType::Float:
|
||||
{
|
||||
auto CNTKMatrixAndMBLayout = GetCNTKImplMatrixAndMBLayoutFromValueObject<float>(iter->first, gradientValue);
|
||||
layout = CNTKMatrixAndMBLayout.second;
|
||||
if (((layout == nullptr) != (nodeLayout == nullptr)) || ((layout != nullptr) && (*layout != *nodeLayout)))
|
||||
InvalidArgument("The layout of the specified gradient Value in incompatible with the layout of the corresponding Variable computed during Forward call");
|
||||
outputComputationNode->As<ComputationNode<float>>()->ResetGradient(*CNTKMatrixAndMBLayout.first);
|
||||
break;
|
||||
}
|
||||
case DataType::Double:
|
||||
{
|
||||
auto CNTKMatrixAndMBLayout = GetCNTKImplMatrixAndMBLayoutFromValueObject<double>(iter->first, gradientValue);
|
||||
layout = CNTKMatrixAndMBLayout.second;
|
||||
if (((layout == nullptr) != (nodeLayout == nullptr)) || ((layout != nullptr) && (*layout != *nodeLayout)))
|
||||
InvalidArgument("The layout of the specified gradient Value in incompatible with the layout of the corresponding Variable computed during Forward call");
|
||||
outputComputationNode->As<ComputationNode<double>>()->ResetGradient(*CNTKMatrixAndMBLayout.first);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
LogicError("Unsupported DataType %s", DataTypeName(gradientValue->Data()->DataType()));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -304,14 +581,17 @@ namespace CNTK
|
|||
{
|
||||
size_t outputValueNumAxes = var.Shape().NumAxes();
|
||||
if (computationNodePtr->GetMBLayout() != nullptr)
|
||||
outputValueNumAxes++;
|
||||
outputValueNumAxes += 2;
|
||||
|
||||
std::vector<size_t> outputShapeDims(outputValueNumAxes);
|
||||
for (size_t i = 0; i < var.Shape().NumAxes(); ++i)
|
||||
outputShapeDims[i] = computationNodePtr->GetSampleLayout().GetDim(i);
|
||||
|
||||
if (computationNodePtr->GetMBLayout() != nullptr)
|
||||
outputShapeDims[var.Shape().NumAxes()] = computationNodePtr->GetMBLayout()->GetNumParallelSequences();
|
||||
{
|
||||
outputShapeDims[var.Shape().NumAxes()] = computationNodePtr->GetMBLayout()->GetNumTimeSteps();
|
||||
outputShapeDims[var.Shape().NumAxes() + 1] = computationNodePtr->GetMBLayout()->GetNumSequences();
|
||||
}
|
||||
|
||||
return NDShape(outputShapeDims);
|
||||
}
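For reference, the shape computed above is the variable's sample shape with the two dynamic dimensions appended: [d0, ..., dn-1, numTimeSteps, numSequences]. A minimal sketch of that construction (illustrative only; plain std::vector instead of NDShape):

#include <cstddef>
#include <vector>

// A Value for a variable with sample shape 'sampleDims', backed by an MBLayout with
// T time steps and S sequences, gets the shape [sampleDims..., T, S].
std::vector<size_t> MakeOutputValueShapeDims(const std::vector<size_t>& sampleDims, size_t T, size_t S)
{
    std::vector<size_t> dims(sampleDims);
    dims.push_back(T); // dynamic (time) axis
    dims.push_back(S); // batch (sequence) axis
    return dims;       // e.g. {512} with T = 11, S = 7 -> {512, 11, 7}
}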
|
||||
|
@ -331,12 +611,38 @@ namespace CNTK
|
|||
if (outputValuePtr->Data()->Shape() != outputShape)
|
||||
InvalidArgument("The shape %s of the specified Value object for output does not match the actual output shape %s", AsString(outputValuePtr->Data()->Shape()).c_str(), AsString(outputShape).c_str());
|
||||
}
|
||||
else
|
||||
|
||||
switch (iter->first.DataType())
|
||||
{
|
||||
outputValuePtr = new Value(new NDArrayView(iter->first.DataType(), outputShape, nullptr, 0, AsDeviceDescriptor(computationNodePtr->ValuePtr()->GetDeviceId())));
|
||||
case DataType::Float:
|
||||
{
|
||||
auto nodeValue = GetValueObjectFromCNTKImplMatrixAndMBLayout<float>(iter->first, computationNodePtr->As<ComputationNode<float>>()->Value(), computationNodePtr->GetMBLayout());
|
||||
if (outputValuePtr == nullptr)
|
||||
{
|
||||
auto data = NDArrayViewPtr(new NDArrayView(iter->first.DataType(), outputShape, nullptr, 0, AsDeviceDescriptor(computationNodePtr->ValuePtr()->GetDeviceId())), [](_ReferenceCounter* ptr) { delete ptr; });
|
||||
auto mask = (nodeValue->Mask() != nullptr) ? NDMaskPtr(new NDMask(nodeValue->Mask()->Shape(), nodeValue->Mask()->Device()), [](_ReferenceCounter* ptr) { delete ptr; }) : nullptr;
|
||||
outputValuePtr = ValuePtr(new Value(data, mask), [](_ReferenceCounter* ptr) { delete ptr; });
|
||||
}
|
||||
outputValuePtr->CopyFrom(*nodeValue);
|
||||
break;
|
||||
}
|
||||
case DataType::Double:
|
||||
{
|
||||
auto nodeValue = GetValueObjectFromCNTKImplMatrixAndMBLayout<double>(iter->first, computationNodePtr->As<ComputationNode<double>>()->Value(), computationNodePtr->GetMBLayout());
|
||||
if (outputValuePtr == nullptr)
|
||||
{
|
||||
auto data = NDArrayViewPtr(new NDArrayView(iter->first.DataType(), outputShape, nullptr, 0, AsDeviceDescriptor(computationNodePtr->ValuePtr()->GetDeviceId())), [](_ReferenceCounter* ptr) { delete ptr; });
|
||||
auto mask = (nodeValue->Mask() != nullptr) ? NDMaskPtr(new NDMask(nodeValue->Mask()->Shape(), nodeValue->Mask()->Device()), [](_ReferenceCounter* ptr) { delete ptr; }) : nullptr;
|
||||
outputValuePtr = ValuePtr(new Value(data, mask), [](_ReferenceCounter* ptr) { delete ptr; });
|
||||
}
|
||||
outputValuePtr->CopyFrom(*nodeValue);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
LogicError("Unsupported DataType %s", DataTypeName(iter->first.DataType()));
|
||||
break;
|
||||
}
|
||||
|
||||
CopyComputationNodeDataToNDArrayView(computationNodePtr, outputValuePtr->Data(), false);
|
||||
outputs[iter->first] = outputValuePtr;
|
||||
}
|
||||
}
|
||||
|
@ -365,15 +671,42 @@ namespace CNTK
|
|||
if (gradientValuePtr->Data()->Shape() != gradientShape)
|
||||
InvalidArgument("The shape %s of the specified Value object for gradient does not match the actual gradient shape %s", AsString(gradientValuePtr->Data()->Shape()).c_str(), AsString(gradientShape).c_str());
|
||||
}
|
||||
else
|
||||
{
|
||||
gradientValuePtr = new Value(new NDArrayView(iter->first.DataType(), gradientShape, nullptr, 0, AsDeviceDescriptor(computationNodePtr->ValuePtr()->GetDeviceId())));
|
||||
}
|
||||
|
||||
if (!computationNodePtr->NeedsGradient())
|
||||
LogicError("Backpropagated gradient value cannot be read from a ComputationNode that has NeedsGradient set to false");
|
||||
|
||||
CopyComputationNodeDataToNDArrayView(computationNodePtr, gradientValuePtr->Data(), true);
|
||||
switch (iter->first.DataType())
|
||||
{
|
||||
case DataType::Float:
|
||||
{
|
||||
auto nodeValue = GetValueObjectFromCNTKImplMatrixAndMBLayout<float>(iter->first, computationNodePtr->As<ComputationNode<float>>()->Gradient(), computationNodePtr->GetMBLayout());
|
||||
if (gradientValuePtr == nullptr)
|
||||
{
|
||||
auto data = NDArrayViewPtr(new NDArrayView(iter->first.DataType(), gradientShape, nullptr, 0, AsDeviceDescriptor(computationNodePtr->ValuePtr()->GetDeviceId())), [](_ReferenceCounter* ptr) { delete ptr; });
|
||||
auto mask = NDMaskPtr((nodeValue->Mask() != nullptr) ? new NDMask(nodeValue->Mask()->Shape(), nodeValue->Mask()->Device()) : nullptr, [](_ReferenceCounter* ptr) { delete ptr; });
|
||||
gradientValuePtr = ValuePtr(new Value(data, mask), [](_ReferenceCounter* ptr) { delete ptr; });
|
||||
}
|
||||
gradientValuePtr->CopyFrom(*nodeValue);
|
||||
break;
|
||||
}
|
||||
case DataType::Double:
|
||||
{
|
||||
auto nodeValue = GetValueObjectFromCNTKImplMatrixAndMBLayout<double>(iter->first, computationNodePtr->As<ComputationNode<double>>()->Gradient(), computationNodePtr->GetMBLayout());
|
||||
if (gradientValuePtr == nullptr)
|
||||
{
|
||||
auto data = NDArrayViewPtr(new NDArrayView(iter->first.DataType(), gradientShape, nullptr, 0, AsDeviceDescriptor(computationNodePtr->ValuePtr()->GetDeviceId())), [](_ReferenceCounter* ptr) { delete ptr; });
|
||||
auto mask = NDMaskPtr((nodeValue->Mask() != nullptr) ? new NDMask(nodeValue->Mask()->Shape(), nodeValue->Mask()->Device()) : nullptr, [](_ReferenceCounter* ptr) { delete ptr; });
|
||||
gradientValuePtr = ValuePtr(new Value(data, mask), [](_ReferenceCounter* ptr) { delete ptr; });
|
||||
|
||||
}
|
||||
gradientValuePtr->CopyFrom(*nodeValue);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
LogicError("Unsupported DataType %s", DataTypeName(iter->first.DataType()));
|
||||
break;
|
||||
}
|
||||
|
||||
gradients[iter->first] = gradientValuePtr;
|
||||
}
|
||||
}
|
||||
|
@ -422,7 +755,7 @@ namespace CNTK
|
|||
|
||||
// TODO: How to deal with the specified 'computeDevice'
|
||||
|
||||
return (outputsToRetainBackwardStateFor.Size() > 0) ? new CNTKBackPropState(this, m_variableToNodeMap[arguments.m_map->begin()->first]->GetEvalTimeStamp()) : nullptr;
|
||||
return (outputsToRetainBackwardStateFor.Size() > 0) ? BackPropStatePtr(new CNTKBackPropState(this, { arguments.m_map->begin()->first, m_variableToNodeMap[arguments.m_map->begin()->first]->GetEvalTimeStamp() }), [](_ReferenceCounter* ptr) { delete ptr; }) : nullptr;
|
||||
}
|
||||
|
||||
/*virtual*/ void CompositeFunction::Backward(const BackPropStatePtr& state,
|
||||
|
@ -434,21 +767,21 @@ namespace CNTK
|
|||
|
||||
// TODO: Support multiple concurrent backprop states
|
||||
auto backpropState = dynamic_cast<const CNTKBackPropState*>(state.GetPtr());
|
||||
if (backpropState->EvalTimeStamp() != m_variableToNodeMap[*(this->Arguments().begin())]->GetEvalTimeStamp())
|
||||
if (backpropState->EvalTimeStamp().second != m_variableToNodeMap[backpropState->EvalTimeStamp().first]->GetEvalTimeStamp())
|
||||
LogicError("The specified backprop state specified cannot be used for backpropagation as the Function's internal state was modified by subsequent Forward calls to the function."
|
||||
"This is not a user error but a shortcoming of the current implementation where multiple independent backprop states are not simultaneously supported");
|
||||
|
||||
// TODO: Avoid copying the data when possible
|
||||
if (rootGradientValues.Size() > 1)
|
||||
LogicError("Currently gradient backprop from only one of the Function Outputs is supported");
|
||||
|
||||
// Feed data into the arguments of the network
|
||||
PopulateNetworkGradients(rootGradientValues);
|
||||
// TODO: Avoid copying the data when possible
|
||||
|
||||
// Zero all gradients of nodes below the root nodes
|
||||
for (auto iter = rootGradientValues.m_map->begin(); iter != rootGradientValues.m_map->end(); ++iter)
|
||||
m_computationNetwork->ZeroInputGradients(m_variableToNodeMap[iter->first]);
|
||||
|
||||
if (rootGradientValues.Size() > 1)
|
||||
LogicError("Currently gradient backprop from only one of the Function Outputs is supported");
|
||||
// Feed data into the arguments of the network
|
||||
PopulateNetworkGradients(rootGradientValues);
|
||||
|
||||
// Backpropagate through the network
|
||||
auto rootComputationNodePtr = m_variableToNodeMap[rootGradientValues.m_map->begin()->first];
|
||||
|
@ -459,6 +792,26 @@ namespace CNTK
|
|||
// TODO: How to deal with the specified 'computeDevice'
|
||||
}
|
||||
|
||||
/*virtual*/ void CompositeFunction::_ReplacePlaceholders(const _Internal::_SimpleMap<Placeholder, Variable>& placeholderReplacements, _Internal::_SimpleSet<const Function*>& visitedFunctions, _Internal::_SimpleSet<Placeholder>& replacedPlaceholders)
|
||||
{
|
||||
RootFunction()->_ReplacePlaceholders(placeholderReplacements, visitedFunctions, replacedPlaceholders);
|
||||
|
||||
// If any of the placeholders were replaced with Output variables, add the graph of the Function underneath each of those to the 'm_allPrimitiveFunctions' set
|
||||
for (auto iter = replacedPlaceholders.m_set->begin(); iter != replacedPlaceholders.m_set->end(); ++iter)
|
||||
{
|
||||
auto replacingVariable = placeholderReplacements[*iter];
|
||||
if (replacingVariable.Kind() == VariableKind::Output)
|
||||
{
|
||||
auto ownerFunc = replacingVariable.Owner();
|
||||
_Internal::_SimpleSet<FunctionPtr> visitedFunctions;
|
||||
_DetermineInputs(ownerFunc, visitedFunctions);
|
||||
|
||||
// Add the newly visited functions to 'm_allPrimitiveFunctions' set
|
||||
m_allPrimitiveFunctions.m_set->insert(visitedFunctions.m_set->begin(), visitedFunctions.m_set->end());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
FunctionPtr Times(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
return CompositeFunction::Create(new PrimitiveFunction(PrimitiveOpType::Times, { leftOperand, rightOperand }, Dictionary(), name), name);
|
||||
|
@ -474,6 +827,11 @@ namespace CNTK
|
|||
return CompositeFunction::Create(new PrimitiveFunction(PrimitiveOpType::Sigmoid, { operand }, Dictionary(), name), name);
|
||||
}
|
||||
|
||||
FunctionPtr Tanh(const Variable& operand, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
return CompositeFunction::Create(new PrimitiveFunction(PrimitiveOpType::Tanh, { operand }, Dictionary(), name), name);
|
||||
}
|
||||
|
||||
FunctionPtr _Combine(const _Internal::_SimpleVector<FunctionPtr>& operands, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
_Internal::_SimpleSet<FunctionPtr> uniqueOperands;
|
||||
|
@ -501,4 +859,39 @@ namespace CNTK
|
|||
{
|
||||
return CompositeFunction::Create(new PrimitiveFunction(PrimitiveOpType::PredictionError, { prediction, labels }, Dictionary(), name), name);
|
||||
}
|
||||
|
||||
FunctionPtr Exp(const Variable& operand, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
return CompositeFunction::Create(new PrimitiveFunction(PrimitiveOpType::Exp, { operand }, Dictionary(), name), name);
|
||||
}
|
||||
|
||||
FunctionPtr PastValue(const Variable& initialState, const Variable& operand, size_t stepSize, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
if (operand.DynamicAxes().size() != 1)
|
||||
InvalidArgument("PastValue overload that does not explicitly specify a dynamic axis can only be used for operands with exactly one dynamic axis");
|
||||
|
||||
auto additionalProperties = Dictionary();
|
||||
additionalProperties[L"stepSize"] = DictionaryValue(stepSize);
|
||||
return CompositeFunction::Create(new PrimitiveFunction(PrimitiveOpType::PastValue, { initialState, operand }, std::move(additionalProperties), name), name);
|
||||
}
|
||||
|
||||
FunctionPtr FutureValue(const Variable& initialState, const Variable& operand, size_t stepSize, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
if (operand.DynamicAxes().size() != 1)
|
||||
InvalidArgument("FutureValue overload that does not explicitly specify a dynamic axis can only be used for operands with exactly one dynamic axis");
|
||||
|
||||
auto additionalProperties = Dictionary();
|
||||
additionalProperties[L"stepSize"] = DictionaryValue(stepSize);
|
||||
return CompositeFunction::Create(new PrimitiveFunction(PrimitiveOpType::FutureValue, { initialState, operand }, std::move(additionalProperties), name), name);
|
||||
}
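A usage sketch for the two ops above (illustrative only; it mirrors the recurrence construction used by the tests added in this change, and the names SimpleRecurrence, W, input, outputDim and device are placeholders): a recurrence y_t = W * x_t + y_(t-1) is formed by routing the output back through PastValue and closing the loop with ReplacePlaceholders.

// Sketch, assuming the CNTKLibrary.h API used elsewhere in this change
template <typename ElementType>
FunctionPtr SimpleRecurrence(Variable input, const Parameter& W, size_t outputDim, const DeviceDescriptor& device)
{
    auto placeholder = Placeholder({ outputDim });
    auto step = Plus(placeholder, Times(W, input));                              // y_t = y_(t-1) + W * x_t
    auto delayed = PastValue(Constant({}, (ElementType)0.0, device), step, 1);   // y_(t-1), zero initial state
    return step->ReplacePlaceholders({ { placeholder, delayed } });              // close the recurrence loop
}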
|
||||
|
||||
FunctionPtr ElementTimes(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
return CompositeFunction::Create(new PrimitiveFunction(PrimitiveOpType::ElementTimes, { leftOperand, rightOperand }, Dictionary(), name), name);
|
||||
}
|
||||
|
||||
FunctionPtr ReduceSum(const Variable& operand, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
return CompositeFunction::Create(new PrimitiveFunction(PrimitiveOpType::ReduceSum, { operand }, Dictionary(), name), name);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,9 +16,15 @@ namespace CNTK
|
|||
Plus,
|
||||
Times,
|
||||
Sigmoid,
|
||||
Tanh,
|
||||
Combine,
|
||||
CrossEntropyWithSoftmax,
|
||||
PredictionError
|
||||
PredictionError,
|
||||
Exp,
|
||||
PastValue,
|
||||
FutureValue,
|
||||
ElementTimes,
|
||||
ReduceSum
|
||||
};
|
||||
|
||||
inline const char* PrimitiveOpTypeName(PrimitiveOpType opType)
|
||||
|
@ -29,12 +35,24 @@ namespace CNTK
|
|||
return "Times";
|
||||
else if (opType == PrimitiveOpType::Sigmoid)
|
||||
return "Sigmoid";
|
||||
else if (opType == PrimitiveOpType::Tanh)
|
||||
return "Tanh";
|
||||
else if (opType == PrimitiveOpType::Combine)
|
||||
return "Combine";
|
||||
else if (opType == PrimitiveOpType::CrossEntropyWithSoftmax)
|
||||
return "CrossEntropyWithSoftmax";
|
||||
else if (opType == PrimitiveOpType::PredictionError)
|
||||
return "PredictionError";
|
||||
else if (opType == PrimitiveOpType::Exp)
|
||||
return "Exp";
|
||||
else if (opType == PrimitiveOpType::PastValue)
|
||||
return "PastValue";
|
||||
else if (opType == PrimitiveOpType::FutureValue)
|
||||
return "FutureValue";
|
||||
else if (opType == PrimitiveOpType::ElementTimes)
|
||||
return "ElementTimes";
|
||||
else if (opType == PrimitiveOpType::ReduceSum)
|
||||
return "ReduceSum";
|
||||
else
|
||||
LogicError("Unknown PrimitiveOpType");
|
||||
}
|
||||
|
@ -50,14 +68,14 @@ namespace CNTK
|
|||
virtual BackPropStatePtr Forward(const _Internal::_SimpleMap<Variable, const ValuePtr>& /*arguments*/,
|
||||
_Internal::_SimpleMap<Variable, ValuePtr>& /*outputs*/,
|
||||
const _Internal::_SimpleSet<Variable>& /*outputsToRetainBackwardStateFor*/,
|
||||
const DeviceDescriptor& /*computeDevice*/)
|
||||
const DeviceDescriptor& /*computeDevice*/) override
|
||||
{
|
||||
NOT_IMPLEMENTED;
|
||||
}
|
||||
|
||||
virtual void Backward(const BackPropStatePtr& /*state*/,
|
||||
const _Internal::_SimpleMap<Variable, const ValuePtr>& /*rootGradientValues*/,
|
||||
_Internal::_SimpleMap<Variable, ValuePtr>& /*backPropagatedGradientValuesForInputs*/)
|
||||
_Internal::_SimpleMap<Variable, ValuePtr>& /*backPropagatedGradientValuesForInputs*/) override
|
||||
{
|
||||
NOT_IMPLEMENTED;
|
||||
}
|
||||
|
@ -129,7 +147,7 @@ namespace CNTK
|
|||
return NDShape(std::move(outputDims));
|
||||
}
|
||||
|
||||
static NDShape ReductionOpOutputShape(PrimitiveOpType op, const NDShape& operandShape, const std::initializer_list<size_t>& reductionAxes)
|
||||
static NDShape ReductionOpOutputShape(PrimitiveOpType op, const NDShape& operandShape, const std::vector<size_t>& reductionAxes)
|
||||
{
|
||||
if (reductionAxes.size() > operandShape.NumAxes())
|
||||
RuntimeError("The number of reduction axes %d exceeds the number of axes in the operand shape %s of the reduction operation %s", reductionAxes.size(), AsString(operandShape).c_str(), PrimitiveOpTypeName(op));
|
||||
|
@ -155,30 +173,74 @@ namespace CNTK
|
|||
// TODO: We are just using input[0]'s DataType as the output node's DataType. This is not always correct
|
||||
DataType outputDataType = inputs[0].DataType();
|
||||
|
||||
// We currently require that the inputs' dynamic axes, if any, match
|
||||
std::vector<Axis> outputDynamicAxes = inputs[0].DynamicAxes();
|
||||
for (size_t i = 1; i < inputs.size(); ++i)
|
||||
{
|
||||
auto currentInputDynamicAxes = inputs[i].DynamicAxes();
|
||||
if (outputDynamicAxes.empty())
|
||||
outputDynamicAxes = currentInputDynamicAxes;
|
||||
else
|
||||
{
|
||||
if (!currentInputDynamicAxes.empty() && (currentInputDynamicAxes != outputDynamicAxes))
|
||||
LogicError("Currently if an operand of a binary elementwise operation has any dynamic axes, those must match the dynamic axes of the other operand");
|
||||
}
|
||||
}
|
||||
|
||||
switch (op)
|
||||
{
|
||||
case PrimitiveOpType::Sigmoid:
|
||||
case PrimitiveOpType::Tanh:
|
||||
case PrimitiveOpType::Exp:
|
||||
assert(inputs.size() == 1);
|
||||
outputs.push_back(Variable(UnaryElementwiseOpOutputShape(inputs[0].Shape()), outputDataType, owner));
|
||||
outputs.push_back(Variable(UnaryElementwiseOpOutputShape(inputs[0].Shape()), outputDataType, owner, outputDynamicAxes));
|
||||
break;
|
||||
case PrimitiveOpType::PastValue:
|
||||
case PrimitiveOpType::FutureValue:
|
||||
assert(inputs.size() == 2);
|
||||
outputs.push_back(Variable(UnaryElementwiseOpOutputShape(inputs[1].Shape()), outputDataType, owner, outputDynamicAxes));
|
||||
break;
|
||||
case PrimitiveOpType::Plus:
|
||||
case PrimitiveOpType::ElementTimes:
|
||||
assert(inputs.size() == 2);
|
||||
outputs.push_back(Variable(BinaryElementwiseOpOutputShape(op, inputs[0].Shape(), inputs[1].Shape()), outputDataType, owner));
|
||||
outputs.push_back(Variable(BinaryElementwiseOpOutputShape(op, inputs[0].Shape(), inputs[1].Shape()), outputDataType, owner, outputDynamicAxes));
|
||||
break;
|
||||
case PrimitiveOpType::Times:
|
||||
assert(inputs.size() == 2);
|
||||
outputs.push_back(Variable(TimesOpOutputShape(inputs[0].Shape(), inputs[1].Shape()), outputDataType, owner));
|
||||
outputs.push_back(Variable(TimesOpOutputShape(inputs[0].Shape(), inputs[1].Shape()), outputDataType, owner, outputDynamicAxes));
|
||||
break;
|
||||
case PrimitiveOpType::CrossEntropyWithSoftmax:
|
||||
case PrimitiveOpType::PredictionError:
|
||||
{
|
||||
assert(inputs.size() == 2);
|
||||
|
||||
if (inputs[0].Shape().NumAxes() > 1)
|
||||
InvalidArgument("The shape of input operands for the %s operation should have at most one axis", PrimitiveOpTypeName(op));
|
||||
|
||||
auto predictionShape = inputs[0].Shape();
|
||||
auto labelsShape = inputs[1].Shape();
|
||||
if (predictionShape != labelsShape)
|
||||
RuntimeError("Prediction output operand's shape %s is incompatible with label operand's shape %s for the %s operation", AsString(predictionShape).c_str(), AsString(labelsShape).c_str(), PrimitiveOpTypeName(op));
|
||||
|
||||
outputs.push_back(Variable(ReductionOpOutputShape(op, predictionShape, { 0 }), outputDataType, owner));
|
||||
std::vector<size_t> reductionAxes;
|
||||
for (size_t i = 0; i < inputs[0].Shape().NumAxes(); ++i)
|
||||
reductionAxes.push_back(i);
|
||||
|
||||
outputs.push_back(Variable(ReductionOpOutputShape(op, predictionShape, reductionAxes), outputDataType, owner, {}));
|
||||
break;
|
||||
}
|
||||
case PrimitiveOpType::ReduceSum:
|
||||
{
|
||||
assert(inputs.size() == 1);
|
||||
|
||||
// TODO: For reductions, we should remove any of the dynamic axes from 'outputDynamicAxes' that are being reduced over.
|
||||
// Currently we only support reductions that reduce over all axes
|
||||
std::vector<Axis> reductionOutputDynamicAxes = {};
|
||||
std::vector<size_t> reductionAxes;
|
||||
for (size_t i = 0; i < inputs[0].Shape().NumAxes(); ++i)
|
||||
reductionAxes.push_back(i);
|
||||
|
||||
outputs.push_back(Variable(ReductionOpOutputShape(op, inputs[0].Shape(), reductionAxes), outputDataType, owner, reductionOutputDynamicAxes));
|
||||
break;
|
||||
}
|
||||
case PrimitiveOpType::Combine:
|
||||
|
@ -200,17 +262,17 @@ namespace CNTK
|
|||
class CNTKBackPropState final : public BackPropState
|
||||
{
|
||||
public:
|
||||
CNTKBackPropState(const FunctionPtr& function, int64_t evalTimeStamp)
|
||||
CNTKBackPropState(const FunctionPtr& function, const std::pair<Variable, int64_t>& evalTimeStamp)
|
||||
: BackPropState(function), m_evalTimeStamp(evalTimeStamp)
|
||||
{}
|
||||
|
||||
int64_t EvalTimeStamp() const
|
||||
std::pair<Variable, int64_t> EvalTimeStamp() const
|
||||
{
|
||||
return m_evalTimeStamp;
|
||||
}
|
||||
|
||||
private:
|
||||
int64_t m_evalTimeStamp;
|
||||
std::pair<Variable, int64_t> m_evalTimeStamp;
|
||||
};
|
||||
typedef _Internal::_ReferenceCounterSharedPtr<CNTKBackPropState> CNTKBackPropStatePtr;
|
||||
|
||||
|
@ -219,33 +281,44 @@ namespace CNTK
|
|||
|
||||
class CompositeFunction final : public Function
|
||||
{
|
||||
friend class Function;
|
||||
|
||||
public:
|
||||
static CompositeFunctionPtr Create(const FunctionPtr& rootFunction, const std::wstring& name = L"")
|
||||
{
|
||||
_Internal::_SimpleSet<FunctionPtr> visitedFunctions;
|
||||
std::vector<Variable> inputs = DetermineInputs(rootFunction, visitedFunctions);
|
||||
auto func = new CompositeFunction(inputs, rootFunction->Outputs(), rootFunction, std::move(visitedFunctions), name);
|
||||
return CompositeFunctionPtr(func, [](_ReferenceCounter* ptr) {
|
||||
delete ptr;
|
||||
});
|
||||
|
||||
// Call _DetermineInputs to get the set of all functions in the graph
|
||||
_DetermineInputs(rootFunction, visitedFunctions);
|
||||
|
||||
auto func = new CompositeFunction(rootFunction, std::move(visitedFunctions), name);
|
||||
return CompositeFunctionPtr(func, [](_ReferenceCounter* ptr) { delete ptr; });
|
||||
}
|
||||
|
||||
virtual BackPropStatePtr Forward(const _Internal::_SimpleMap<Variable, const ValuePtr>& arguments,
|
||||
_Internal::_SimpleMap<Variable, ValuePtr>& outputs,
|
||||
const _Internal::_SimpleSet<Variable>& outputsToRetainBackwardStateFor,
|
||||
const DeviceDescriptor& computeDevice);
|
||||
const DeviceDescriptor& computeDevice) override;
|
||||
|
||||
virtual void Backward(const BackPropStatePtr& state,
|
||||
const _Internal::_SimpleMap<Variable, const ValuePtr>& rootGradientValues,
|
||||
_Internal::_SimpleMap<Variable, ValuePtr>& backPropagatedGradientValuesForInputs);
|
||||
_Internal::_SimpleMap<Variable, ValuePtr>& backPropagatedGradientValuesForInputs) override;
|
||||
|
||||
private:
|
||||
CompositeFunction(const std::vector<Variable>& inputs, const std::vector<Variable>& outputs, const FunctionPtr& rootFunction, _Internal::_SimpleSet<FunctionPtr>&& allPrimitiveFunctions, const std::wstring& name)
|
||||
: Function(inputs, outputs, rootFunction, name), m_allPrimitiveFunctions(std::move(allPrimitiveFunctions))
|
||||
virtual void _ReplacePlaceholders(const _Internal::_SimpleMap<Placeholder, Variable>& placeholderReplacements, _Internal::_SimpleSet<const Function*>& visitedFunctions, _Internal::_SimpleSet<Placeholder>& replacedPlaceholders) override;
|
||||
|
||||
CompositeFunction(const FunctionPtr& rootFunction, _Internal::_SimpleSet<FunctionPtr>&& allPrimitiveFunctions, const std::wstring& name)
|
||||
: Function({}, rootFunction->Outputs(), rootFunction, name), m_allPrimitiveFunctions(std::move(allPrimitiveFunctions))
|
||||
{
|
||||
}
|
||||
|
||||
static std::vector<Variable> DetermineInputs(const FunctionPtr& rootFunction, _Internal::_SimpleSet<FunctionPtr>& visitedFunctions)
|
||||
std::vector<Variable> DetermineInputs() const
|
||||
{
|
||||
_Internal::_SimpleSet<FunctionPtr> visitedFunctions;
|
||||
return _DetermineInputs(RootFunction(), visitedFunctions);
|
||||
}
|
||||
|
||||
static std::vector<Variable> _DetermineInputs(const FunctionPtr& rootFunction, _Internal::_SimpleSet<FunctionPtr>& visitedFunctions)
|
||||
{
|
||||
visitedFunctions.Insert(rootFunction);
|
||||
|
||||
|
@ -259,7 +332,7 @@ namespace CNTK
|
|||
else if (!visitedFunctions.Contains(currentInput.Owner()))
|
||||
{
|
||||
FunctionPtr function = currentInput.Owner();
|
||||
std::vector<Variable> functionInputs = DetermineInputs(function, visitedFunctions);
|
||||
std::vector<Variable> functionInputs = _DetermineInputs(function, visitedFunctions);
|
||||
std::copy(functionInputs.begin(), functionInputs.end(), std::back_inserter(inputs));
|
||||
}
|
||||
}
|
||||
|
@ -271,10 +344,10 @@ namespace CNTK
|
|||
Microsoft::MSR::CNTK::ComputationNetworkPtr GetComputationNetwork(const DeviceDescriptor& device, const _Internal::_SimpleSet<Variable>& backpropRoots);
|
||||
|
||||
template <typename ElementType>
|
||||
static Microsoft::MSR::CNTK::ComputationNodeBasePtr GetOutputVariableNode(const Variable& variable, Microsoft::MSR::CNTK::ComputationNetworkBuilder<ElementType>& builder, std::unordered_map<Variable, Microsoft::MSR::CNTK::ComputationNodeBasePtr>& variableToNodeMap, std::unordered_map<Variable, bool>& isVariableRootMap);
|
||||
static Microsoft::MSR::CNTK::ComputationNodeBasePtr GetOutputVariableNode(const Variable& variable, Microsoft::MSR::CNTK::ComputationNetworkPtr& network, Microsoft::MSR::CNTK::ComputationNetworkBuilder<ElementType>& builder, std::unordered_map<Variable, Microsoft::MSR::CNTK::ComputationNodeBasePtr>& variableToNodeMap, std::unordered_map<Variable, bool>& isVariableRootMap);
|
||||
|
||||
template <typename ElementType>
|
||||
static Microsoft::MSR::CNTK::ComputationNodeBasePtr GetNode(const Variable& variable, Microsoft::MSR::CNTK::ComputationNetworkBuilder<ElementType>& builder, std::unordered_map<Variable, Microsoft::MSR::CNTK::ComputationNodeBasePtr>& variableToNodeMap, std::unordered_map<Variable, bool>& isVariableRootMap);
|
||||
static Microsoft::MSR::CNTK::ComputationNodeBasePtr GetNode(const Variable& variable, Microsoft::MSR::CNTK::ComputationNetworkPtr& network, Microsoft::MSR::CNTK::ComputationNetworkBuilder<ElementType>& builder, std::unordered_map<Variable, Microsoft::MSR::CNTK::ComputationNodeBasePtr>& variableToNodeMap, std::unordered_map<Variable, bool>& isVariableRootMap);
|
||||
|
||||
void PopulateNetworkInputs(const _Internal::_SimpleMap<Variable, const ValuePtr>& arguments);
|
||||
void PopulateNetworkGradients(const _Internal::_SimpleMap<Variable, const ValuePtr>& gradients);
|
||||
|
@ -282,10 +355,11 @@ namespace CNTK
|
|||
void GetNetworkOutputs(std::unordered_map<Variable, ValuePtr>& outputs);
|
||||
void GetNetworkGradients(std::unordered_map<Variable, ValuePtr>& gradients);
|
||||
|
||||
static void CopyNDArrayViewToComputationNodeValue(const NDArrayViewPtr& arrayView, Microsoft::MSR::CNTK::ComputationNodeBasePtr node);
|
||||
static void CopyNDArrayViewToComputationNodeGradient(const NDArrayViewPtr& arrayView, Microsoft::MSR::CNTK::ComputationNodeBasePtr node);
|
||||
template <typename ElementType>
|
||||
static std::pair<std::shared_ptr<const Microsoft::MSR::CNTK::Matrix<ElementType>>, Microsoft::MSR::CNTK::MBLayoutPtr> GetCNTKImplMatrixAndMBLayoutFromValueObject(Variable var, const ValuePtr& value);
|
||||
|
||||
static void CopyComputationNodeDataToNDArrayView(const Microsoft::MSR::CNTK::ComputationNodeBasePtr& node, NDArrayViewPtr arrayView, bool copyGradient);
|
||||
template <typename ElementType>
|
||||
static ValuePtr GetValueObjectFromCNTKImplMatrixAndMBLayout(Variable var, const Microsoft::MSR::CNTK::Matrix<ElementType>& matrix, const Microsoft::MSR::CNTK::MBLayoutPtr& layout);
|
||||
|
||||
private:
|
||||
_Internal::_SimpleSet<FunctionPtr> m_allPrimitiveFunctions;
|
||||
|
|
|
@ -62,21 +62,21 @@ void TestFeedForwardNetworkCreation(const DeviceDescriptor& device)
|
|||
size_t iterationCount = 4;
|
||||
unsigned int randSeed = 2;
|
||||
srand(randSeed);
|
||||
size_t numSamples = 1;
|
||||
size_t numSamples = 3;
|
||||
for (size_t i = 0; i < iterationCount; ++i)
|
||||
{
|
||||
std::vector<float> inputData(inputDim * numSamples);
|
||||
for (size_t i = 0; i < inputData.size(); ++i)
|
||||
inputData[i] = ((float)rand()) / RAND_MAX;
|
||||
|
||||
NDShape inputShape = { inputDim, numSamples };
|
||||
NDShape inputShape = { inputDim, 1, numSamples };
|
||||
ValuePtr inputValue = new Value(new NDArrayView(inputShape, inputData.data(), inputData.size(), DeviceDescriptor::CPUDevice(), true));
|
||||
|
||||
std::vector<float> labelData(numOutputClasses * numSamples, 0);
|
||||
for (size_t i = 0; i < numSamples; ++i)
|
||||
labelData[(i*numOutputClasses) + (rand() % numOutputClasses)] = 1;
|
||||
|
||||
NDShape labelShape = { numOutputClasses, numSamples };
|
||||
NDShape labelShape = { numOutputClasses, 1, numSamples };
|
||||
ValuePtr labelValue = new Value(new NDArrayView(labelShape, labelData.data(), labelData.size(), DeviceDescriptor::CPUDevice(), true));
|
||||
|
||||
ValuePtr outputValue, predictionErrorValue;
|
||||
|
@ -119,10 +119,10 @@ void TestTimesAndPlus(size_t inputDim,
|
|||
for (size_t i = 0; i < inputData.size(); ++i)
|
||||
inputData[i] = ((ElementType)rand()) / RAND_MAX;
|
||||
|
||||
NDShape inputShape = { inputDim, numSamples };
|
||||
NDShape inputShape = { inputDim, 1, numSamples };
|
||||
ValuePtr inputValue = new Value(new NDArrayView(inputShape, inputData.data(), inputData.size(), DeviceDescriptor::CPUDevice(), true));
|
||||
|
||||
NDShape outputShape = { outputDim, numSamples };
|
||||
NDShape outputShape = { outputDim, 1, numSamples };
|
||||
std::vector<ElementType> outputData(outputShape.TotalSize());
|
||||
ValuePtr outputValue;
|
||||
if (usePreAllocatedOutputs)
|
||||
|
|
|
@ -88,12 +88,14 @@
|
|||
<GenerateDebugInformation>true</GenerateDebugInformation>
|
||||
<EnableCOMDATFolding>true</EnableCOMDATFolding>
|
||||
<OptimizeReferences>true</OptimizeReferences>
|
||||
<AdditionalDependencies Condition="'$(Configuration)|$(Platform)'=='Release|x64'">CNTKLibrary-2.0.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
</Link>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemGroup>
|
||||
<ClCompile Include="FeedForwardTests.cpp" />
|
||||
<ClCompile Include="Main.cpp" />
|
||||
<ClCompile Include="NDArrayViewTests.cpp" />
|
||||
<ClCompile Include="RecurrentFunctionTests.cpp" />
|
||||
<ClCompile Include="TensorTests.cpp" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
|
|
|
@ -27,6 +27,9 @@
|
|||
<ClCompile Include="TensorTests.cpp">
|
||||
<Filter>Source Files</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="RecurrentFunctionTests.cpp">
|
||||
<Filter>Source Files</Filter>
|
||||
</ClCompile>
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="Common.h">
|
||||
|
|
|
@ -4,10 +4,12 @@
|
|||
void NDArrayViewTests();
|
||||
void TensorTests();
|
||||
void FeedForwardTests();
|
||||
void RecurrentFunctionTests();
|
||||
|
||||
int main()
|
||||
{
|
||||
NDArrayViewTests();
|
||||
TensorTests();
|
||||
FeedForwardTests();
|
||||
RecurrentFunctionTests();
|
||||
}
|
||||
|
|
|
@ -0,0 +1,440 @@
|
|||
#include "CNTKLibrary.h"
|
||||
#include <functional>
|
||||
#include "Common.h"
|
||||
#include <numeric>
|
||||
|
||||
using namespace CNTK;
|
||||
|
||||
static unsigned long seed = 1;
|
||||
|
||||
template <typename ElementType>
|
||||
std::pair<FunctionPtr, FunctionPtr> LSTMPCellWithSelfStabilization(Variable input, Variable prevOutput, Variable prevCellState, const DeviceDescriptor& device)
|
||||
{
|
||||
assert(input.Shape().NumAxes() == 1);
|
||||
size_t inputDim = input.Shape()[0];
|
||||
|
||||
size_t outputDim = prevOutput.Shape()[0];
|
||||
size_t cellDim = prevCellState.Shape()[0];
|
||||
|
||||
auto Wxo = Parameter(NDArrayView::RandomUniform<ElementType>({ cellDim, inputDim }, -0.5, 0.5, seed++, device));
|
||||
auto Wxi = Parameter(NDArrayView::RandomUniform<ElementType>({ cellDim, inputDim }, -0.5, 0.5, seed++, device));
|
||||
auto Wxf = Parameter(NDArrayView::RandomUniform<ElementType>({ cellDim, inputDim }, -0.5, 0.5, seed++, device));
|
||||
auto Wxc = Parameter(NDArrayView::RandomUniform<ElementType>({ cellDim, inputDim }, -0.5, 0.5, seed++, device));
|
||||
|
||||
auto Bo = Parameter({ cellDim }, (ElementType)0.0, device);
|
||||
auto Bc = Parameter({ cellDim }, (ElementType)0.0, device);
|
||||
auto Bi = Parameter({ cellDim }, (ElementType)0.0, device);
|
||||
auto Bf = Parameter({ cellDim }, (ElementType)0.0, device);
|
||||
|
||||
auto Whi = Parameter(NDArrayView::RandomUniform<ElementType>({ cellDim, outputDim }, -0.5, 0.5, seed++, device));
|
||||
auto Wci = Parameter(NDArrayView::RandomUniform<ElementType>({ cellDim }, -0.5, 0.5, seed++, device));
|
||||
|
||||
auto Whf = Parameter(NDArrayView::RandomUniform<ElementType>({ cellDim, outputDim }, -0.5, 0.5, seed++, device));
|
||||
auto Wcf = Parameter(NDArrayView::RandomUniform<ElementType>({ cellDim }, -0.5, 0.5, seed++, device));
|
||||
|
||||
auto Who = Parameter(NDArrayView::RandomUniform<ElementType>({ cellDim, outputDim }, -0.5, 0.5, seed++, device));
|
||||
auto Wco = Parameter(NDArrayView::RandomUniform<ElementType>({ cellDim }, -0.5, 0.5, seed++, device));
|
||||
|
||||
auto Whc = Parameter(NDArrayView::RandomUniform<ElementType>({ cellDim, outputDim }, -0.5, 0.5, seed++, device));
|
||||
|
||||
auto Wmr = Parameter(NDArrayView::RandomUniform<ElementType>({ outputDim, cellDim }, -0.5, 0.5, seed++, device));
|
||||
|
||||
// Stabilization by routing input through an extra scalar parameter
|
||||
auto sWxo = Parameter({}, (ElementType)0.0, device);
|
||||
auto sWxi = Parameter({}, (ElementType)0.0, device);
|
||||
auto sWxf = Parameter({}, (ElementType)0.0, device);
|
||||
auto sWxc = Parameter({}, (ElementType)0.0, device);
|
||||
|
||||
auto sWhi = Parameter({}, (ElementType)0.0, device);
|
||||
auto sWci = Parameter({}, (ElementType)0.0, device);
|
||||
|
||||
auto sWhf = Parameter({}, (ElementType)0.0, device);
|
||||
auto sWcf = Parameter({}, (ElementType)0.0, device);
|
||||
auto sWho = Parameter({}, (ElementType)0.0, device);
|
||||
auto sWco = Parameter({}, (ElementType)0.0, device);
|
||||
auto sWhc = Parameter({}, (ElementType)0.0, device);
|
||||
|
||||
auto sWmr = Parameter({}, (ElementType)0.0, device);
|
||||
|
||||
auto expsWxo = Exp(sWxo);
|
||||
auto expsWxi = Exp(sWxi);
|
||||
auto expsWxf = Exp(sWxf);
|
||||
auto expsWxc = Exp(sWxc);
|
||||
|
||||
auto expsWhi = Exp(sWhi);
|
||||
auto expsWci = Exp(sWci);
|
||||
|
||||
auto expsWhf = Exp(sWhf);
|
||||
auto expsWcf = Exp(sWcf);
|
||||
auto expsWho = Exp(sWho);
|
||||
auto expsWco = Exp(sWco);
|
||||
auto expsWhc = Exp(sWhc);
|
||||
|
||||
auto expsWmr = Exp(sWmr);
|
||||
|
||||
auto Wxix = Times(Wxi, ElementTimes(expsWxi, input));
|
||||
auto Whidh = Times(Whi, ElementTimes(expsWhi, prevOutput));
|
||||
auto Wcidc = ElementTimes(Wci, ElementTimes(expsWci, prevCellState));
|
||||
|
||||
auto it = Sigmoid(Plus(Plus(Plus(Wxix, Bi), Whidh), Wcidc));
|
||||
|
||||
auto Wxcx = Times(Wxc, ElementTimes(expsWxc, input));
|
||||
auto Whcdh = Times(Whc, ElementTimes(expsWhc, prevOutput));
|
||||
auto bit = ElementTimes(it, Tanh(Plus(Wxcx, Plus(Whcdh, Bc))));
|
||||
|
||||
auto Wxfx = Times(Wxf, ElementTimes(expsWxf, input));
|
||||
auto Whfdh = Times(Whf, ElementTimes(expsWhf, prevOutput));
|
||||
auto Wcfdc = ElementTimes(Wcf, ElementTimes(expsWcf, prevCellState));
|
||||
|
||||
auto ft = Sigmoid(Plus(Plus(Plus(Wxfx, Bf), Whfdh), Wcfdc));
|
||||
|
||||
auto bft = ElementTimes(ft, prevCellState);
|
||||
|
||||
auto ct = Plus(bft, bit);
|
||||
|
||||
auto Wxox = Times(Wxo, ElementTimes(expsWxo, input));
|
||||
auto Whodh = Times(Who, ElementTimes(expsWho, prevOutput));
|
||||
auto Wcoct = ElementTimes(Wco, ElementTimes(expsWco, ct));
|
||||
|
||||
auto ot = Sigmoid(Plus(Plus(Plus(Wxox, Bo), Whodh), Wcoct));
|
||||
|
||||
auto mt = ElementTimes(ot, Tanh(ct));
|
||||
|
||||
return{ Times(Wmr, ElementTimes(expsWmr, mt)), ct };
|
||||
}
|
||||
|
||||
template <typename ElementType>
|
||||
FunctionPtr LSTMPComponentWithSelfStabilization(Variable input, size_t outputDim, size_t cellDim, const DeviceDescriptor& device)
|
||||
{
|
||||
auto dh = Placeholder({ outputDim });
|
||||
auto dc = Placeholder({ cellDim });
|
||||
|
||||
auto LSTMCell = LSTMPCellWithSelfStabilization<ElementType>(input, dh, dc, device);
|
||||
|
||||
auto actualDh = PastValue(Constant({}, (ElementType)0.0, device), LSTMCell.first, 1);
|
||||
auto actualDc = PastValue(Constant({}, (ElementType)0.0, device), LSTMCell.second, 1);
|
||||
|
||||
// Form the recurrence loop by replacing the dh and dc placeholders with the actualDh and actualDc
|
||||
return LSTMCell.first->ReplacePlaceholders({ { dh, actualDh }, { dc, actualDc } });
|
||||
}
|
||||
|
||||
template <typename ElementType>
|
||||
FunctionPtr LSTMNet(Variable features, size_t cellDim, size_t hiddenDim, size_t numOutputClasses, size_t numLSTMLayers, const DeviceDescriptor& device)
|
||||
{
|
||||
assert(numLSTMLayers >= 1);
|
||||
auto classifierRoot = LSTMPComponentWithSelfStabilization<ElementType>(features, hiddenDim, cellDim, device);
|
||||
for (size_t i = 1; i < numLSTMLayers; ++i) {
|
||||
classifierRoot = LSTMPComponentWithSelfStabilization<ElementType>(classifierRoot, hiddenDim, cellDim, device);
|
||||
}
|
||||
|
||||
auto W = Parameter(NDArrayView::RandomUniform<ElementType>({ numOutputClasses, hiddenDim }, -0.5, 0.5, seed++, device));
|
||||
auto b = Parameter({ numOutputClasses }, (ElementType)0.0, device);
|
||||
|
||||
auto sW = Parameter({}, (ElementType)0.0, device);
|
||||
auto expsW = Exp(sW);
|
||||
|
||||
return Plus(Times(W, ElementTimes(expsW, classifierRoot)), b);
|
||||
}
|
||||
|
||||
template <typename ElementType>
|
||||
void TestRecurrentNetworkCreation(const DeviceDescriptor& device)
|
||||
{
|
||||
const size_t inputDim = 937;
|
||||
const size_t numLSTMLayers = 3;
|
||||
const size_t cellDim = 1024;
|
||||
const size_t hiddenDim = 512;
|
||||
const size_t numOutputClasses = 9304;
|
||||
|
||||
Variable features({ inputDim }, GetDataType<ElementType>(), L"Features");
|
||||
auto classifierOutputFunction = LSTMNet<ElementType>(features, cellDim, hiddenDim, numOutputClasses, numLSTMLayers, device);
|
||||
|
||||
Variable labelsVar = Variable({ numOutputClasses }, GetDataType<ElementType>(), L"Labels");
|
||||
auto trainingLossFunction = CrossEntropyWithSoftmax(classifierOutputFunction, labelsVar, L"lossFunction");
|
||||
auto predictionFunction = PredictionError(classifierOutputFunction, labelsVar, L"predictionError");
|
||||
|
||||
auto LSTMClassifier = Combine({ trainingLossFunction, predictionFunction, classifierOutputFunction }, L"LSTMClassifier");
|
||||
|
||||
// Now test the structure
|
||||
if (LSTMClassifier->Arguments().size() != 2)
|
||||
throw std::exception("TestFeedForwardNetworkCreation: Function does not have expected Argument count");
|
||||
|
||||
if (LSTMClassifier->Outputs().size() != 3)
|
||||
throw std::exception("TestFeedForwardNetworkCreation: Function does not have expected Output count");
|
||||
|
||||
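// Each LSTM layer above contributes 28 Parameters (16 weight/bias Parameters plus 12 scalar stabilizer parameters); the top-level classifier adds W, b and sW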
if (LSTMClassifier->Parameters().size() != ((numLSTMLayers * 28) + 3))
|
||||
throw std::exception("TestFeedForwardNetworkCreation: Function does not have expected Parameter count");
|
||||
|
||||
// Run Forward and backward a few times
|
||||
size_t iterationCount = 3;
|
||||
unsigned int randSeed = 2;
|
||||
srand(randSeed);
|
||||
size_t numSequences = 7;
|
||||
size_t maxAllowedSequenceLength = 11;
|
||||
for (size_t i = 0; i < iterationCount; ++i)
|
||||
{
|
||||
std::vector<size_t> sequenceLengths(numSequences);
|
||||
size_t maxActualSequenceLength = 0;
|
||||
for (size_t i = 0; i < numSequences; ++i)
|
||||
{
|
||||
sequenceLengths[i] = (rand() % maxAllowedSequenceLength) + 1;
|
||||
if (sequenceLengths[i] > maxActualSequenceLength)
|
||||
maxActualSequenceLength = sequenceLengths[i];
|
||||
}
|
||||
|
||||
std::vector<const std::vector<ElementType>> inputSequences;
|
||||
for (size_t i = 0; i < numSequences; ++i)
|
||||
{
|
||||
std::vector<ElementType> currentSequence(inputDim * sequenceLengths[i]);
|
||||
for (size_t j = 0; j < currentSequence.size(); ++j)
|
||||
currentSequence[j] = ((ElementType)rand()) / RAND_MAX;
|
||||
|
||||
inputSequences.push_back(std::move(currentSequence));
|
||||
}
|
||||
|
||||
ValuePtr inputValue = Value::Create({ inputDim }, inputSequences, device, true);
|
||||
|
||||
std::vector<const std::vector<ElementType>> labelsData;
|
||||
for (size_t i = 0; i < numSequences; ++i)
|
||||
{
|
||||
std::vector<ElementType> currentSequence(numOutputClasses * sequenceLengths[i]);
|
||||
for (size_t j = 0; j < sequenceLengths[i]; ++j)
|
||||
currentSequence[(j * numOutputClasses) + (rand() % numOutputClasses)] = 1;
|
||||
|
||||
labelsData.push_back(std::move(currentSequence));
|
||||
}
|
||||
|
||||
ValuePtr labelValue = Value::Create({ numOutputClasses }, labelsData, device, true);
|
||||
|
||||
ValuePtr outputValue, predictionErrorValue;
|
||||
std::unordered_map<Variable, ValuePtr> outputs = { { classifierOutputFunction->Output(), outputValue }, { predictionFunction->Output(), predictionErrorValue } };
|
||||
auto backpropState = LSTMClassifier->Forward({ { features, inputValue }, { labelsVar, labelValue } }, outputs, device, { trainingLossFunction->Output() });
|
||||
|
||||
// Perform backprop
|
||||
NDShape outputShape = trainingLossFunction->Output().Shape();
|
||||
std::vector<ElementType> rootGradientsData(outputShape.TotalSize(), 1);
|
||||
ValuePtr rootGradientValue = new Value(new NDArrayView(outputShape, rootGradientsData.data(), rootGradientsData.size(), DeviceDescriptor::CPUDevice(), true));
|
||||
std::unordered_map<Variable, ValuePtr> paramGradients;
|
||||
auto allParams = LSTMClassifier->Parameters();
|
||||
for (auto iter = allParams.begin(); iter != allParams.end(); ++iter)
|
||||
paramGradients[*iter] = nullptr;
|
||||
|
||||
LSTMClassifier->Backward(backpropState, { { trainingLossFunction->Output(), rootGradientValue } }, paramGradients);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ElementType>
|
||||
void TestSimpleRecurrence(size_t inputDim,
|
||||
size_t outputDim,
|
||||
size_t maxAllowedSequenceLength,
|
||||
size_t numSequences,
|
||||
const DeviceDescriptor& device,
|
||||
size_t numIterations,
|
||||
bool useFutureValue,
|
||||
unsigned int seed = 1)
|
||||
{
|
||||
Parameter timesParam(new NDArrayView((ElementType)0.5, { outputDim, inputDim }, device));
|
||||
Parameter plusParam(new NDArrayView((ElementType)0.1, { outputDim }, device));
|
||||
|
||||
Variable inputVar({ inputDim }, GetDataType<ElementType>(), true, L"input");
|
||||
|
||||
auto placeholder = Placeholder({ outputDim });
|
||||
auto plusOutput = Plus(plusParam, Plus(placeholder, Times(timesParam, inputVar)));
|
||||
FunctionPtr placeholderReplacement;
|
||||
if (useFutureValue)
|
||||
placeholderReplacement = FutureValue(Constant({}, (ElementType)0.0, device), plusOutput, 1);
|
||||
else
|
||||
placeholderReplacement = PastValue(Constant({}, (ElementType)0.0, device), plusOutput, 1);
|
||||
|
||||
plusOutput = plusOutput->ReplacePlaceholders({ { placeholder, placeholderReplacement } });
|
||||
|
||||
auto reducedOutput = ReduceSum(plusOutput);
|
||||
|
||||
auto rootFunc = Combine({ reducedOutput, plusOutput });
|
||||
|
||||
srand(seed);
|
||||
for (size_t iterIdx = 0; iterIdx < numIterations; ++iterIdx)
|
||||
{
|
||||
std::vector<size_t> sequenceLengths(numSequences);
|
||||
size_t maxActualSequenceLength = 0;
|
||||
for (size_t i = 0; i < numSequences; ++i)
|
||||
{
|
||||
sequenceLengths[i] = (rand() % maxAllowedSequenceLength) + 1;
|
||||
if (sequenceLengths[i] > maxActualSequenceLength)
|
||||
maxActualSequenceLength = sequenceLengths[i];
|
||||
}
|
||||
|
||||
size_t totalNumInputSamples = maxActualSequenceLength * numSequences;
|
||||
std::vector<ElementType> inputData(inputDim * totalNumInputSamples, std::numeric_limits<ElementType>::quiet_NaN());
|
||||
for (size_t i = 0; i < numSequences; ++i)
|
||||
{
|
||||
for (size_t j = 0; j < maxActualSequenceLength; ++j)
|
||||
{
|
||||
size_t sampleIdx = (i * maxActualSequenceLength) + j;
|
||||
for (size_t k = 0; k < inputDim; ++k)
|
||||
{
|
||||
if (j < sequenceLengths[i])
|
||||
inputData[(sampleIdx * inputDim) + k] = ((ElementType)rand()) / RAND_MAX;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
NDShape inputShape = { inputDim, maxActualSequenceLength, numSequences };
|
||||
|
||||
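// Mask the padding positions, i.e. the steps at or beyond each sequence's actual length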
NDMaskPtr inputMask = new NDMask({ maxActualSequenceLength, numSequences }, DeviceDescriptor::CPUDevice());
|
||||
for (size_t i = 0; i < numSequences; ++i)
|
||||
inputMask->MaskSection({ sequenceLengths[i], i }, {NDShape::InferredDimension, 1});
|
||||
|
||||
ValuePtr inputValue = new Value(new NDArrayView(inputShape, inputData.data(), inputData.size(), DeviceDescriptor::CPUDevice(), true), inputMask);
|
||||
|
||||
NDShape reducedOutputShape = {};
|
||||
std::vector<ElementType> reducedOutputData(reducedOutputShape.TotalSize());
|
||||
ValuePtr reducedOutputValue = new Value(new NDArrayView(reducedOutputShape, reducedOutputData.data(), reducedOutputData.size(), DeviceDescriptor::CPUDevice(), false));
|
||||
|
||||
NDShape plusOutputShape = plusOutput->Output().Shape().AppendShape({ maxActualSequenceLength, numSequences });
|
||||
std::vector<ElementType> plusOutputData(plusOutputShape.TotalSize());
|
||||
ValuePtr plusOutputValue = new Value(new NDArrayView(plusOutputShape, plusOutputData.data(), plusOutputData.size(), DeviceDescriptor::CPUDevice(), false), new NDMask(inputMask->Shape(), inputMask->Device()));
|
||||
|
||||
std::unordered_map<Variable, ValuePtr> outputs = { { reducedOutput->Output(), reducedOutputValue }, { plusOutput->Output(), plusOutputValue } };
|
||||
auto backpropState = rootFunc->Forward({ { inputVar, inputValue } }, outputs, device, { plusOutput->Output() });
|
||||
|
||||
// Perform backprop
|
||||
std::vector<ElementType> rootGradientsData(plusOutputShape.TotalSize(), std::numeric_limits<ElementType>::quiet_NaN());
|
||||
for (size_t i = 0; i < numSequences; ++i)
|
||||
{
|
||||
for (size_t j = 0; j < maxActualSequenceLength; ++j)
|
||||
{
|
||||
size_t sampleIdx = (i * maxActualSequenceLength) + j;
|
||||
for (size_t k = 0; k < outputDim; ++k)
|
||||
{
|
||||
if (j < sequenceLengths[i])
|
||||
rootGradientsData[(sampleIdx * outputDim) + k] = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ValuePtr rootGradientValue = new Value(new NDArrayView(plusOutputShape, rootGradientsData.data(), rootGradientsData.size(), DeviceDescriptor::CPUDevice(), true), inputMask->DeepClone());
|
||||
|
||||
std::vector<ElementType> plusParameterGradientData(plusParam.Shape().TotalSize());
|
||||
std::vector<ElementType> timesParameterGradientData(timesParam.Shape().TotalSize());
|
||||
std::vector<ElementType> inputGradientData(inputShape.TotalSize());
|
||||
ValuePtr plusParameterGradientValue = new Value(new NDArrayView(plusParam.Shape(), plusParameterGradientData.data(), plusParameterGradientData.size(), DeviceDescriptor::CPUDevice(), false));
|
||||
ValuePtr timesParameterGradientValue = new Value(new NDArrayView(timesParam.Shape(), timesParameterGradientData.data(), timesParameterGradientData.size(), DeviceDescriptor::CPUDevice(), false));
|
||||
ValuePtr inputGradientValue = new Value(new NDArrayView(inputShape, inputGradientData.data(), inputGradientData.size(), DeviceDescriptor::CPUDevice(), false), inputMask->DeepClone());
|
||||
|
||||
std::unordered_map<Variable, ValuePtr> outGradients = { { inputVar, inputGradientValue }, { plusParam, plusParameterGradientValue }, { timesParam, timesParameterGradientValue } };
|
||||
rootFunc->Backward(backpropState, { { plusOutput->Output(), rootGradientValue } }, outGradients);
|
||||
|
||||
// Verify forward prop results
|
||||
std::vector<ElementType> expectedPlusOutputData(plusOutputShape.TotalSize(), 0);
|
||||
ElementType expectedReducedValue = 0;
|
||||
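// Note on the expected values: per output element, y_j = (v_j + 0.1) + y_(j-1) for PastValue, with v_j = 0.5 * sum_k x_(j,k),
// so the step-j contribution (v_j + 0.1) reappears in y_j, y_(j+1), ..., y_(L-1), i.e. (currentSequenceLength - j) times in the reduced sum;
// for FutureValue the recurrence runs in the opposite direction and the corresponding factor is (j + 1)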
for (size_t i = 0; i < numSequences; ++i)
|
||||
{
|
||||
size_t currentSequenceLength = sequenceLengths[i];
|
||||
if (useFutureValue)
|
||||
{
|
||||
for (int j = (int)(currentSequenceLength - 1); j >= 0; j--)
|
||||
{
|
||||
ElementType value = 0;
|
||||
for (size_t k = 0; k < inputDim; ++k)
|
||||
value += (ElementType)(0.5 * inputData[(((i * maxActualSequenceLength) + j) * inputDim) + k]);
|
||||
|
||||
for (size_t k = 0; k < outputDim; ++k)
|
||||
{
|
||||
expectedPlusOutputData[(((i * maxActualSequenceLength) + j) * outputDim) + k] = (ElementType)(value + 0.1);
|
||||
|
||||
if (j != (currentSequenceLength - 1))
|
||||
expectedPlusOutputData[(((i * maxActualSequenceLength) + j) * outputDim) + k] += expectedPlusOutputData[(((i * maxActualSequenceLength) + (j + 1)) * outputDim) + k];
|
||||
}
|
||||
|
||||
expectedReducedValue += (outputDim * (ElementType)((value + 0.1) * (j + 1)));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (size_t j = 0; j < currentSequenceLength; j++)
|
||||
{
|
||||
ElementType value = 0;
|
||||
for (size_t k = 0; k < inputDim; ++k)
|
||||
value += (ElementType)(0.5 * inputData[(((i * maxActualSequenceLength) + j) * inputDim) + k]);
|
||||
|
||||
for (size_t k = 0; k < outputDim; ++k)
|
||||
{
|
||||
expectedPlusOutputData[(((i * maxActualSequenceLength) + j) * outputDim) + k] = (ElementType)(value + 0.1);
|
||||
|
||||
if (j != 0)
|
||||
expectedPlusOutputData[(((i * maxActualSequenceLength) + j) * outputDim) + k] += expectedPlusOutputData[(((i * maxActualSequenceLength) + (j - 1)) * outputDim) + k];
|
||||
}
|
||||
|
||||
expectedReducedValue += (outputDim * (ElementType)((value + 0.1) * (currentSequenceLength - j)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
FloatingPointVectorCompare(reducedOutputData, std::vector<ElementType>({ expectedReducedValue }), "TestSimpleRecurrence: Forward prop results do not match expected results");
|
||||
FloatingPointVectorCompare(plusOutputData, expectedPlusOutputData, "TestSimpleRecurrence: Forward prop results do not match expected results");
|
||||
|
||||
// Verify backward prop results
|
||||
ElementType expectedPlusParameterGradientValue = 0;
|
||||
for (size_t i = 0; i < numSequences; ++i)
|
||||
{
|
||||
size_t currentSequenceLength = sequenceLengths[i];
|
||||
expectedPlusParameterGradientValue += (currentSequenceLength * (currentSequenceLength + 1)) / 2;
|
||||
}
|
||||
|
||||
for (size_t k = 0; k < plusParam.Shape().TotalSize(); ++k)
|
||||
if (plusParameterGradientData[k] != expectedPlusParameterGradientValue)
|
||||
throw std::exception("TestSimpleRecurrence: Backprop prop results do not match expected results for Plus params gradients");
|
||||
|
||||
std::vector<ElementType> expectedTimesParamsGradientValues(timesParam.Shape().TotalSize(), 0);
|
||||
for (size_t i = 0; i < numSequences; ++i)
|
||||
{
|
||||
size_t currentSequenceLength = sequenceLengths[i];
|
||||
for (size_t k = 0; k < inputDim; ++k)
|
||||
{
|
||||
ElementType gradVal = 0;
|
||||
for (size_t j = 0; j < currentSequenceLength; j++)
|
||||
{
|
||||
if (useFutureValue)
|
||||
gradVal += (j + 1) * inputData[(((i * maxActualSequenceLength) + j) * inputDim) + k];
|
||||
else
|
||||
gradVal += (currentSequenceLength - j) * inputData[(((i * maxActualSequenceLength) + j) * inputDim) + k];
|
||||
}
|
||||
|
||||
for (size_t j = 0; j < outputDim; ++j)
|
||||
expectedTimesParamsGradientValues[(k * outputDim) + j] += gradVal;
|
||||
}
|
||||
}
|
||||
|
||||
FloatingPointVectorCompare(timesParameterGradientData, expectedTimesParamsGradientValues, "TestSimpleRecurrence: Backprop results do not match expected results for Times params gradients");
|
||||
|
||||
std::vector<ElementType> expectedInputGradientValues(inputShape.TotalSize(), 0);
|
||||
for (size_t i = 0; i < numSequences; ++i)
|
||||
{
|
||||
size_t currentSequenceLength = sequenceLengths[i];
|
||||
for (size_t j = 0; j < currentSequenceLength; j++)
|
||||
{
|
||||
ElementType gradVal = 0;
|
||||
for (size_t k = 0; k < outputDim; ++k)
|
||||
{
|
||||
if (useFutureValue)
|
||||
gradVal += (ElementType)((j + 1) * 0.5);
|
||||
else
|
||||
gradVal += (ElementType)((currentSequenceLength - j) * 0.5);
|
||||
}
|
||||
|
||||
for (size_t k = 0; k < inputDim; ++k)
|
||||
expectedInputGradientValues[(((i * maxActualSequenceLength) + j) * inputDim) + k] = gradVal;
|
||||
}
|
||||
}
|
||||
|
||||
FloatingPointVectorCompare(inputGradientData, expectedInputGradientValues, "TestSimpleRecurrence: Backprop prop results do not match expected results for Times params gradients");
|
||||
}
|
||||
}
|
||||
|
||||
void RecurrentFunctionTests()
|
||||
{
|
||||
TestSimpleRecurrence<float>(2, 1, 4, 1, DeviceDescriptor::CPUDevice(), 3, false);
|
||||
TestSimpleRecurrence<double>(11, 9, 16, 7, DeviceDescriptor::GPUDevice(0), 5, true);
|
||||
TestRecurrentNetworkCreation<float>(DeviceDescriptor::GPUDevice(0));
|
||||
TestRecurrentNetworkCreation<double>(DeviceDescriptor::CPUDevice());
|
||||
}
|
|
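For reference, the expected values computed above encode the recurrence being tested (a sketch; it assumes the Times parameter entries are 0.5 and the Plus parameter is 0.1, as the expressions above suggest):

    y_j = 0.5 \sum_k x_{j,k} + 0.1 + y_{j-1}    (PastValue variant)
    y_j = 0.5 \sum_k x_{j,k} + 0.1 + y_{j+1}    (FutureValue variant)

with y taken as 0 outside the valid steps of a sequence; the reduced output is the sum of all y_j components over every sequence, which is what expectedReducedValue accumulates.
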
@ -29,17 +29,17 @@ void TestTensorPlus(size_t numAxesLeftOperand, size_t numAxesRightOperand, const
    for (size_t i = 0; i < leftInputData.size(); ++i)
        leftInputData[i] = ((ElementType)rand()) / RAND_MAX;

    auto leftInputValueShape = leftInputShape.AppendShape({ 1 });
    auto leftInputValueShape = leftInputShape.AppendShape({ 1, 1 });
    ValuePtr leftInputValue = new Value(new NDArrayView(leftInputValueShape, leftInputData, true));

    std::vector<ElementType> rightInputData(rightInputShape.TotalSize());
    for (size_t i = 0; i < rightInputData.size(); ++i)
        rightInputData[i] = ((ElementType)rand()) / RAND_MAX;

    auto rightInputValueShape = rightInputShape.AppendShape({ 1 });
    auto rightInputValueShape = rightInputShape.AppendShape({ 1, 1 });
    ValuePtr rightInputValue = new Value(new NDArrayView(rightInputValueShape, rightInputData, true));

    NDShape outputShape = plusFunc->Output().Shape().AppendShape({ 1 });
    NDShape outputShape = plusFunc->Output().Shape().AppendShape({ 1, 1 });
    std::vector<ElementType> outputData(outputShape.TotalSize());
    ValuePtr outputValue = new Value(new NDArrayView(outputShape, outputData, false));

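The `{ 1, 1 }` append reflects that Value data now carries two trailing axes (time steps and sequences) rather than one; a minimal sketch with an illustrative sample shape:

    // Sample shape plus one time step and one sequence
    NDShape sampleShape({ 2, 3 });
    NDShape valueShape = sampleShape.AppendShape({ 1, 1 });
    // valueShape is [2 x 3 x 1 x 1]: sample axes followed by the sequence-length and batch axes
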
@ -15,18 +15,6 @@ using namespace Microsoft::MSR::CNTK;

namespace CNTK
{
    static std::pair<size_t, size_t> GetMatrixDimensions(const NDShape& viewShape)
    {
        // Ensure none of the shape dimensions are unknown
        if (viewShape.HasInferredDimension())
            InvalidArgument("Cannot create an NDArrayView using a view shape that has unknown dimensions for any of its axes!");

        size_t matrixRowSize = (viewShape.NumAxes() > 0) ? viewShape[0] : 1;
        size_t matrixColSize = (viewShape.NumAxes() > 0) ? viewShape.SubShape(1).TotalSize() : 1;

        return{ matrixRowSize, matrixColSize };
    }

    template <typename ElementType>
    static void* AllocateTensorView(const NDShape& viewShape,
                                    const DeviceDescriptor& device,

@ -103,30 +91,34 @@ namespace CNTK
    }

    template <typename ElementType>
    std::shared_ptr<Matrix<ElementType>> NDArrayView::GetMatrixImpl() const
    std::shared_ptr<Matrix<ElementType>> NDArrayView::GetMatrixImpl(size_t rowColSplitPoint) const
    {
        const TensorView<ElementType>* tensorView = GetTensorView<ElementType>();
        auto tensorShape = tensorView->GetShape();
        if (tensorShape.GetRank() <= 2)
            return tensorView->AsMatrix();

        // Determine the split point
        std::vector<bool> dimsToDrop(tensorShape.GetRank(), false);
        for (size_t k = 1; k < tensorShape.GetRank(); ++k)
            if (tensorShape.CanFlatten(k))
                dimsToDrop[k - 1] = true;
        size_t splitPoint = rowColSplitPoint;
        if (splitPoint == AutoSelectRowColSplitPoint)
        {
            // Determine the split point
            std::vector<bool> dimsToDrop(tensorShape.GetRank(), false);
            for (size_t k = 1; k < tensorShape.GetRank(); ++k)
                if (tensorShape.CanFlatten(k))
                    dimsToDrop[k - 1] = true;

        // There should be at most 2 dims we cannot drop
        auto numDimsThatCannotBeDropped = std::count_if(dimsToDrop.begin(), dimsToDrop.end(), [](const bool& val) {
            return !val;
        });
            // There should be at most 2 dims we cannot drop
            auto numDimsThatCannotBeDropped = std::count_if(dimsToDrop.begin(), dimsToDrop.end(), [](const bool& val) {
                return !val;
            });

        if (numDimsThatCannotBeDropped > 2)
            LogicError("The TensorView underlying this NDArrayView cannot be flattened to a Matrix");
            if (numDimsThatCannotBeDropped > 2)
                LogicError("The TensorView underlying this NDArrayView cannot be flattened to a Matrix");

        size_t splitPoint = 0;
        while (dimsToDrop[splitPoint])
            splitPoint++;
            splitPoint = 0;
            while (dimsToDrop[splitPoint])
                splitPoint++;
        }

        tensorShape.FlattenTo2DInPlace(splitPoint, "NDArrayView::GetMatrix");

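For intuition, the row/column split point selects which leading axes are folded into the matrix rows and which trailing axes into the columns. A small sketch with illustrative dimensions (not code from this commit):

    NDShape shape({ 2, 3, 4 });
    size_t splitPoint = 1;
    size_t rows = 1, cols = 1;
    for (size_t k = 0; k < shape.NumAxes(); ++k)
        (k < splitPoint ? rows : cols) *= shape[k];
    // splitPoint == 1 gives a 2 x 12 matrix; splitPoint == 2 would give 6 x 4
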
@ -134,18 +126,18 @@ namespace CNTK
    }

    template <typename ElementType>
    std::shared_ptr<const Matrix<ElementType>> NDArrayView::GetMatrix() const
    std::shared_ptr<const Matrix<ElementType>> NDArrayView::GetMatrix(size_t rowColSplitPoint/* = AutoSelectRowColSplitPoint*/) const
    {
        return GetMatrixImpl<ElementType>();
        return GetMatrixImpl<ElementType>(rowColSplitPoint);
    }

    template <typename ElementType>
    std::shared_ptr<Matrix<ElementType>> NDArrayView::GetWritableMatrix()
    std::shared_ptr<Matrix<ElementType>> NDArrayView::GetWritableMatrix(size_t rowColSplitPoint/* = AutoSelectRowColSplitPoint*/)
    {
        if (IsReadOnly())
            LogicError("NDArrayView::GetWritableMatrix: Cannot get writable Matrix from a read-only NDArrayView");

        return GetMatrixImpl<ElementType>();
        return GetMatrixImpl<ElementType>(rowColSplitPoint);
    }

    template <typename ElementType>

@ -168,7 +160,7 @@ namespace CNTK

    NDArrayViewPtr NDArrayView::DeepClone(bool readOnly/* = false*/) const
    {
        NDArrayViewPtr newView = new NDArrayView(this->DataType(), this->Shape(), nullptr, 0, this->Device(), readOnly);
        NDArrayViewPtr newView(new NDArrayView(this->DataType(), this->Shape(), nullptr, 0, this->Device(), readOnly), [](_ReferenceCounter* ptr) { delete ptr; });
        switch (m_dataType)
        {
        case DataType::Float:

@ -242,9 +234,7 @@ namespace CNTK
    }

        auto aliasView = new NDArrayView(DataType(), Device(), StorageFormat(), Shape(), IsReadOnly() || readOnly, tensorView);
        return NDArrayViewPtr(aliasView, [](_ReferenceCounter* ptr) {
            delete ptr;
        });
        return NDArrayViewPtr(aliasView, [](_ReferenceCounter* ptr) { delete ptr; });
    }

    // TODO: This could actually be strided?

@ -278,18 +268,16 @@ namespace CNTK
        auto tensorView = new TensorView<ElementType>(randomUniformMatrix, AsTensorShape(shape));

        auto view = new NDArrayView(GetDataType<ElementType>(), device, StorageFormat::Dense, shape, false, tensorView);
        return NDArrayViewPtr(view, [](_ReferenceCounter* ptr) {
            delete ptr;
        });
        return NDArrayViewPtr(view, [](_ReferenceCounter* ptr) { delete ptr; });
    }

    // Explicit template instantiations
    template CNTK_API NDArrayViewPtr NDArrayView::RandomUniform<float>(const NDShape& shape, double rangeStart, double rangeEnd, unsigned long seed, const DeviceDescriptor& device/* = DeviceDescriptor::DefaultDevice()*/);
    template CNTK_API NDArrayViewPtr NDArrayView::RandomUniform<double>(const NDShape& shape, double rangeStart, double rangeEnd, unsigned long seed, const DeviceDescriptor& device/* = DeviceDescriptor::DefaultDevice()*/);

    template CNTK_API const float* NDArrayView::DataBuffer<float>() const;
    template CNTK_API const double* NDArrayView::DataBuffer<double>() const;

    template CNTK_API float* NDArrayView::WritableDataBuffer<float>();
    template CNTK_API double* NDArrayView::WritableDataBuffer<double>();

}

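A hedged usage sketch for the RandomUniform factory instantiated above (the shape, range, and seed values are illustrative, not from this commit):

    // Creates a dense [3 x 4] float NDArrayView on the CPU with entries drawn uniformly between -0.05 and 0.05
    NDArrayViewPtr initialW = NDArrayView::RandomUniform<float>(NDShape({ 3, 4 }), -0.05, 0.05, 1, DeviceDescriptor::CPUDevice());
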
@ -0,0 +1,109 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//

#include "stdafx.h"
#include "CNTKLibrary.h"
#include "Utils.h"
#include "Matrix.h"
#include <algorithm>
#include "TensorShape.h"

using namespace Microsoft::MSR::CNTK;

namespace CNTK
{
    static Matrix<char>* AllocateMatrix(const NDShape& viewShape, const DeviceDescriptor& device)
    {
        auto matrixDims = GetMatrixDimensions(viewShape);
        auto maskMatrix = new Matrix<char>(matrixDims.first, matrixDims.second, AsCNTKImplDeviceId(device));
        maskMatrix->SetValue(1);

        return maskMatrix;
    }

    NDMask::NDMask(const NDShape& shape, Matrix<char>* matrix)
        : m_device(AsDeviceDescriptor(matrix->GetDeviceId())), m_maskShape(shape), m_matrixView(matrix)
    {
    }

    NDMask::NDMask(const NDShape& shape, const DeviceDescriptor& device/* = DeviceDescriptor::DefaultDevice()*/)
        : NDMask(shape, AllocateMatrix(shape, device))
    {
        if (shape.NumAxes() > 2)
            LogicError("NDMask instances with more than 2 axes are currently unsupported");
    }

    NDMask::~NDMask()
    {
        delete m_matrixView;
    }

    void NDMask::MaskSection(const std::vector<size_t>& sectionOffset, const NDShape& sectionShape)
    {
        // TODO: Implement batching of masking operation for masks residing on GPUs to avoid making
        // GPU invocations for each MaskSection call.

        if (sectionOffset.size() > m_maskShape.NumAxes())
            LogicError("NDMask::MaskSection: The sectionOffset cannot have dimensionality higher than the number of axes of 'this' mask");

        if (sectionShape.NumAxes() > m_maskShape.NumAxes())
            LogicError("NDMask::MaskSection: The section shape cannot have an axes count higher than the number of axes of 'this' mask");

        std::vector<size_t> offset(m_maskShape.NumAxes(), 0);
        for (size_t i = 0; i < sectionOffset.size(); ++i)
            offset[i] = sectionOffset[i];

        NDShape shape = sectionShape.AppendShape(NDShape(m_maskShape.NumAxes() - sectionShape.NumAxes(), NDShape::InferredDimension));

        auto maskMatrix = GetMatrix();
        size_t rowOffset = offset[0];
        size_t colOffset = offset[1];
        size_t sliceRowLength = (shape[0] != NDShape::InferredDimension) ? shape[0] : (maskMatrix->GetNumRows() - rowOffset);
        size_t sliceColLength = (shape[1] != NDShape::InferredDimension) ? shape[1] : (maskMatrix->GetNumCols() - colOffset);
        if ((rowOffset == 0) && (sliceRowLength == maskMatrix->GetNumRows()))
            maskMatrix->ColumnSlice(colOffset, sliceColLength).SetValue(0);
        else
        {
            // Since Matrix does not support strides in the row dimension, we will need to create separate slices for each column
            for (size_t i = colOffset; i < (colOffset + sliceColLength); ++i)
            {
                auto column = maskMatrix->ColumnSlice(i, 1);
                column.Reshape(1, maskMatrix->GetNumRows());
                column.ColumnSlice(rowOffset, sliceRowLength).SetValue(0);
            }
        }
    }

    void NDMask::Clear()
    {
        GetMatrix()->SetValue(1);
    }

    Matrix<char>* NDMask::GetMatrix() const
    {
        return m_matrixView;
    }

    void NDMask::CopyFrom(const NDMask& source)
    {
        if (source.Shape() != Shape())
            InvalidArgument("NDMask::CopyFrom: The 'source' mask's shape must be the same as the shape of this NDMask");

        GetMatrix()->AssignValuesOf(*source.GetMatrix());
    }

    NDMaskPtr NDMask::DeepClone() const
    {
        NDMaskPtr newMask = new NDMask(this->Shape(), this->Device());
        newMask->CopyFrom(*this);

        return NDMaskPtr(newMask, [](_ReferenceCounter* ptr) { delete ptr; });
    }

    NDMaskPtr NDMask::Alias() const
    {
        return NDMaskPtr(new NDMask(this->Shape(), new Matrix<char>(GetMatrix()->AsReference())), [](_ReferenceCounter* ptr) { delete ptr; });
    }
}

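A minimal usage sketch of the new NDMask, mirroring how Value::Create (further below in this commit) marks the padded tail of a short sequence; the concrete lengths are illustrative:

    // Mask laid out as [maxSequenceLength x numSequences]; all entries start as 1 (valid)
    NDMask mask(NDShape({ 5, 2 }), DeviceDescriptor::CPUDevice());
    // Sequence 1 has only 3 valid steps: zero out steps 3..4 in column 1
    mask.MaskSection({ 3, 1 }, { NDShape::InferredDimension, 1 });
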
@ -215,6 +215,7 @@ namespace CNTK
    // Explicit template instantiations
    template class _SimpleSet<FunctionPtr>;
    template class _SimpleSet<Variable>;
    template class _SimpleSet<const Function*>;

    template bool operator==(const _SimpleSet<Variable>& first, const _SimpleSet<Variable>& second);

@ -314,6 +315,7 @@ namespace CNTK

    // Explicit template instantiations
    template class _SimpleMap<Variable, const ValuePtr>;
    template class _SimpleMap<Placeholder, Variable>;

#pragma endregion _SimpleMap

@ -351,4 +353,14 @@ namespace CNTK
    {
        return (*m_dictionaryData)[key];
    }

    DictionaryValue Dictionary::operator[](const wchar_t* key) const
    {
        return m_dictionaryData->at(key);
    }

    bool Dictionary::Contains(const wchar_t* key) const
    {
        return (m_dictionaryData->find(key) != m_dictionaryData->end());
    }
}

@ -233,9 +233,22 @@ namespace CNTK
            return operator[](key.c_str());
        }

    private:
        DictionaryValue& operator[](const wchar_t* key);

        DictionaryValue operator[](const std::wstring& key) const
        {
            return operator[](key.c_str());
        }

        DictionaryValue operator[](const wchar_t* key) const;

        bool Contains(const std::wstring& key) const
        {
            return Contains(key.c_str());
        }

        bool Contains(const wchar_t* key) const;

    private:
        std::unordered_map<std::wstring, DictionaryValue>* m_dictionaryData;
    };

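A hedged sketch of reading through the const accessors declared above (assuming they are public; the key name is hypothetical):

    // 'dict' is a const Dictionary&
    if (dict.Contains(L"hiddenLayerDim"))
    {
        DictionaryValue dim = dict[L"hiddenLayerDim"];   // const operator[] forwards to at()
        // ... use dim ...
    }
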
@ -308,4 +321,16 @@ namespace CNTK

        return shapeString + "]";
    }

    inline std::pair<size_t, size_t> GetMatrixDimensions(const NDShape& viewShape)
    {
        // Ensure none of the shape dimensions are unknown
        if (viewShape.HasInferredDimension())
            InvalidArgument("Cannot create an NDArrayView using a view shape that has unknown dimensions for any of its axes!");

        size_t matrixRowSize = (viewShape.NumAxes() > 0) ? viewShape[0] : 1;
        size_t matrixColSize = (viewShape.NumAxes() > 0) ? viewShape.SubShape(1).TotalSize() : 1;

        return{ matrixRowSize, matrixColSize };
    }
}

@ -8,19 +8,125 @@
namespace CNTK
{
    Value::Value(const NDArrayViewPtr& data)
        : m_data(data)
        : Value(data, nullptr)
    {
    }

    Value::Value(const NDArrayViewPtr& data, const NDMaskPtr& mask)
        : m_data(data), m_mask(mask)
    {
        if ((mask != nullptr) && (mask->Shape().NumAxes() > data->Shape().NumAxes()))
            InvalidArgument("The number of axes of the mask of a Value object cannot exceed the number of axes of the data NDArrayView object");

        if (mask != nullptr)
        {
            auto dataShape = data->Shape();
            auto maskShape = mask->Shape();
            if (dataShape.SubShape(dataShape.NumAxes() - maskShape.NumAxes()) != maskShape)
                InvalidArgument("Invalid Value object; the data and mask are incompatible. The trailing dimensions of the data do not match the dimensions of the mask");
        }
    }

    template <typename ElementType>
    /*static*/ ValuePtr Value::Create(const NDShape& sampleShape, const std::vector<const std::vector<ElementType>>& sequences, const DeviceDescriptor& device, bool readOnly/* = false*/)
    {
        size_t numSequences = sequences.size();
        std::vector<size_t> sequenceLengths(numSequences);
        size_t sampleSize = sampleShape.TotalSize();
        size_t maxSequenceLength = 0;
        bool needsMask = false;
        for (size_t i = 0; i < numSequences; ++i)
        {
            sequenceLengths[i] = sequences[i].size() / sampleSize;
            if (maxSequenceLength < sequenceLengths[i])
                maxSequenceLength = sequenceLengths[i];

            if ((i > 0) && (sequenceLengths[i - 1] != sequenceLengths[i]))
                needsMask = true;
        }

        NDShape valueDataShape = sampleShape.AppendShape({ maxSequenceLength, numSequences });
        NDArrayViewPtr valueData(new NDArrayView(GetDataType<ElementType>(), valueDataShape, DeviceDescriptor::CPUDevice()), [](_ReferenceCounter* ptr) { delete ptr; });
        ElementType* dataBuffer = valueData->WritableDataBuffer<ElementType>();
        for (size_t i = 0; i < numSequences; ++i)
            std::copy(sequences[i].data(), sequences[i].data() + sequences[i].size(), dataBuffer + (maxSequenceLength * i * sampleSize));

        NDArrayViewPtr deviceValueData;
        if (device == DeviceDescriptor::CPUDevice())
        {
            if (readOnly)
                deviceValueData = valueData->Alias(true);
            else
                deviceValueData = valueData;
        }
        else
        {
            deviceValueData = NDArrayViewPtr(new NDArrayView(GetDataType<ElementType>(), valueDataShape, device), [](_ReferenceCounter* ptr) { delete ptr; });
            deviceValueData->CopyFrom(*valueData);
            if (readOnly)
                deviceValueData = deviceValueData->Alias(true);
        }

        NDMaskPtr deviceValueMask;
        if (needsMask)
        {
            NDShape valueMaskShape = { maxSequenceLength, numSequences };
            deviceValueMask = NDMaskPtr(new NDMask(valueMaskShape, device), [](_ReferenceCounter* ptr) { delete ptr; });
            for (size_t i = 0; i < numSequences; ++i)
                deviceValueMask->MaskSection({ sequenceLengths[i], i }, { NDShape::InferredDimension, 1 });
        }

        return ValuePtr(new Value(deviceValueData, deviceValueMask), [](_ReferenceCounter* ptr) { delete ptr; });
    }

    /*virtual*/ Value::~Value()
    {
    }

    /*virtual*/ NDArrayViewPtr Value::Data() const
    {
        if (m_data == nullptr)
            LogicError("The data NDArrayView underlying this 'Value' object is null!");

        // TODO: Check if this is a derived type and throw an exception in that case
        return m_data;
    }

    /*virtual*/ NDMaskPtr Value::Mask() const
    {
        // TODO: Check if this is a derived type and throw an exception in that case
        return m_mask;
    }

    /*virtual*/ ValuePtr Value::DeepClone(bool readOnly/* = false*/) const
    {
        // TODO: Check if this is a derived type and throw an exception in that case
        return ValuePtr(new Value(Data()->DeepClone(readOnly), (Mask() != nullptr) ? Mask()->DeepClone() : nullptr), [](_ReferenceCounter* ptr) { delete ptr; });
    }

    /*virtual*/ ValuePtr Value::Alias(bool readOnly/* = false*/) const
    {
        // TODO: Check if this is a derived type and throw an exception in that case
        return ValuePtr(new Value(Data()->Alias(readOnly), (Mask() != nullptr) ? Mask()->Alias() : nullptr), [](_ReferenceCounter* ptr) { delete ptr; });
    }

    /*virtual*/ void Value::CopyFrom(const Value& source)
    {
        // TODO: Check if this is a derived type and throw an exception in that case
        Data()->CopyFrom(*source.Data());
        if ((Mask() == nullptr) && (source.Mask() != nullptr))
            InvalidArgument("Value::CopyFrom: Invalid source object; Cannot copy a Value with a mask into 'this' Value that does not have a mask.");

        if (source.Mask() != nullptr)
            Mask()->CopyFrom(*source.Mask());
        else
        {
            if (Mask() != nullptr)
            {
                // Clear the mask
                Mask()->Clear();
            }
        }
    }

    // Explicit template instantiations
    template /*static*/ CNTK_API ValuePtr Value::Create<float>(const NDShape& sampleShape, const std::vector<const std::vector<float>>& sequences, const DeviceDescriptor& device, bool readOnly/* = false*/);
    template /*static*/ CNTK_API ValuePtr Value::Create<double>(const NDShape& sampleShape, const std::vector<const std::vector<double>>& sequences, const DeviceDescriptor& device, bool readOnly/* = false*/);
}

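A hedged sketch of building a variable-length batch with the new Value::Create (the sample dimension, lengths, and data are illustrative):

    // Two sequences of a 2-dimensional sample: lengths 3 and 1
    NDShape sampleShape({ 2 });
    std::vector<const std::vector<float>> sequences = { { 1, 2, 3, 4, 5, 6 }, { 7, 8 } };
    // The data is packed into a [2 x 3 x 2] NDArrayView; a [3 x 2] NDMask marks the
    // padded tail (steps 1..2) of the second, shorter sequence
    ValuePtr batch = Value::Create<float>(sampleShape, sequences, DeviceDescriptor::CPUDevice());
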
@ -196,7 +196,6 @@ public:
    // packing algorithm
    // - width: maximum width of structure; set to maximum over sequence lengths
    // - inputSequences: vector of input SequenceInfo records (only seqId and GetNumTimeSteps() are used)
    // - [out] *pMBLayout: MBLayout that describes the created packed sequence set
    // - placement, rowAllocations: temp buffers (passed in to be able to optimize memory allocations)
    template<typename SequenceInfoVector>
    void InitAsPackedSequences(const SequenceInfoVector& inputSequences,

@ -255,6 +254,12 @@ public:

    size_t GetNumTimeSteps() const { return m_numTimeSteps; }
    size_t GetNumParallelSequences() const { return m_numParallelSequences; }
    size_t GetNumSequences() const
    {
        return std::count_if(m_sequences.begin(), m_sequences.end(), [](const SequenceInfo& sequence) {
            return sequence.seqId != GAP_SEQUENCE_ID;
        });
    }

    // axis names are for now only a debugging aid
    // In the future, there will be a mechanism to denote that axes are meant to be the same.

@ -6313,8 +6313,10 @@ template void CPUMatrix<char>::SetValue(CPUMatrix<char> const&);
//template void CPUMatrix<char>::SetValue(GPUSparseMatrix<char> const&);
template void CPUMatrix<char>::RequireSize(const size_t numRows, const size_t numCols, bool growOnly);
template void CPUMatrix<char>::Resize(const size_t numRows, const size_t numCols, bool growOnly);
template char* CPUMatrix<char>::CopyToArray(void) const;

template void CPUMatrix<char>::CopySection(size_t numRows, size_t numCols, char* dst, size_t colStride) const;
template void CPUMatrix<char>::Reshape(const size_t, const size_t);

template CPUMatrix<int>::CPUMatrix(const size_t, const size_t, int*, const size_t);
template CPUMatrix<int>::~CPUMatrix();

@ -4410,6 +4410,7 @@ template void GPUMatrix<char>::SetValue(GPUMatrix<char> const&);
//template void GPUMatrix<char>::SetValue(GPUSparseMatrix<char> const&);

template void GPUMatrix<char>::CopySection(size_t numRows, size_t numCols, char* dst, size_t colStride) const;
template void GPUMatrix<char>::Reshape(const size_t, const size_t);

template GPUMatrix<int>::GPUMatrix(const size_t, const size_t, int, int*, const size_t);
template GPUMatrix<int>::~GPUMatrix();

@ -2676,6 +2676,7 @@ template GPUSparseMatrix<char>::~GPUSparseMatrix();
template GPUSparseMatrix<char> GPUSparseMatrix<char>::ColumnSlice(size_t, size_t) const;
template GPUMatrix<char> GPUSparseMatrix<char>::CopyColumnSliceToDense(size_t, size_t) const;
template GPUSparseMatrix<char>& GPUSparseMatrix<char>::operator=(GPUSparseMatrix<char>&&);
template void GPUSparseMatrix<char>::Reshape(const size_t, const size_t);

template GPUSparseMatrix<int>::GPUSparseMatrix(DEVICEID_TYPE, const MatrixFormat);
template GPUSparseMatrix<int>::~GPUSparseMatrix();

@ -5404,6 +5404,8 @@ template void Matrix<char>::SetValue(const Matrix<char>&);
template void Matrix<char>::AssignValuesOf (const Matrix<char>&);
template bool Matrix<char>::IsEmpty() const;
template void Matrix<char>::Resize(const size_t numRows, const size_t numCols, const size_t numNZElemToReserve, bool growOnly);
template void Matrix<char>::Reshape(const size_t, const size_t);
template char* Matrix<char>::CopyToArray(void) const;

template Matrix<int>::Matrix(const size_t, const size_t, int*, DEVICEID_TYPE, const size_t, const size_t);