CNTK v2 library: Fix a bug in convolution inference and other misc bug fixes

Amit Agarwal 2016-10-20 15:21:27 -07:00
Parent fce86cd958
Commit e6036b2735
10 changed files with 81 additions and 40 deletions

View File

@@ -884,8 +884,8 @@ namespace CNTK
  CNTK_API static const std::wstring StaticAxisNamePrefix;
  CNTK_API static const int SentinelStaticAxisIndexValueForDynamicAxes;
- static const int SentinelStaticAxisIndexValueForAllStaticAxes;
- static const int SentinelStaticAxisIndexValueForUnknownAxes;
+ CNTK_API static const int SentinelStaticAxisIndexValueForAllStaticAxes;
+ CNTK_API static const int SentinelStaticAxisIndexValueForUnknownAxes;
  class UniqueDynamicAxesNames
  {
@@ -953,7 +953,20 @@ namespace CNTK
  ///
  /// Returns a boolean indicating if 'this' Axis corresponds to a static axis
  ///
- bool IsStaticAxis() const { return m_staticAxisIdx != SentinelStaticAxisIndexValueForDynamicAxes; }
+ bool IsStaticAxis() const
+ {
+     return ((m_staticAxisIdx != SentinelStaticAxisIndexValueForDynamicAxes) &&
+             (m_staticAxisIdx != SentinelStaticAxisIndexValueForAllStaticAxes) &&
+             (m_staticAxisIdx != SentinelStaticAxisIndexValueForUnknownAxes));
+ }
+ ///
+ /// Returns a boolean indicating if 'this' Axis corresponds to a dynamic axis
+ ///
+ bool IsDynamicAxis() const
+ {
+     return (m_staticAxisIdx == SentinelStaticAxisIndexValueForDynamicAxes);
+ }
  ///
  /// Returns a boolean indicating if 'this' Axis is ordered; i.e. if there is an ordering between the dimensions along this axis.
@@ -1017,11 +1030,11 @@ namespace CNTK
  inline bool operator==(const Axis& first, const Axis& second)
  {
-     if (first.IsStaticAxis() != second.IsStaticAxis())
+     if (first.IsDynamicAxis() != second.IsDynamicAxis())
          return false;
-     if (first.IsStaticAxis())
-         return first.StaticAxisIndex() == second.StaticAxisIndex();
+     if (!first.IsDynamicAxis())
+         return first.StaticAxisIndex(/*checked =*/ false) == second.StaticAxisIndex(/*checked =*/ false);
      else
          return first.Name() == second.Name();
  }
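With this change, the static and dynamic predicates are no longer complementary: sentinel axes such as Axis::AllStaticAxes() now report false for both IsStaticAxis() and IsDynamicAxis(), which is why operator== above branches on IsDynamicAxis() and compares the raw, unchecked static axis index. A minimal illustrative sketch (not part of the commit), assuming the CNTK v2 C++ API:

// Illustrative only: behavior of the new Axis predicates, assuming CNTKLibrary.h.
#include "CNTKLibrary.h"
#include <cassert>

void CheckAxisPredicates()
{
    using namespace CNTK;

    Axis staticAxis(0);
    assert(staticAxis.IsStaticAxis());   // a concrete static axis index
    assert(!staticAxis.IsDynamicAxis());

    // Sentinel axes are now neither static nor dynamic; under the old
    // operator== they would have fallen through to a Name() comparison.
    assert(!Axis::AllStaticAxes().IsStaticAxis());
    assert(!Axis::AllStaticAxes().IsDynamicAxis());
    assert(Axis::AllStaticAxes() == Axis::AllStaticAxes()); // compared via the unchecked index
}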
@@ -1915,14 +1928,14 @@ private:
  ///
  template<typename ElemType>
  Parameter(const NDShape& shape, ElemType initValue, const DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice(), const std::wstring& name = L"")
-     : Variable(shape, VariableKind::Parameter, AsDataType<ElemType>(), MakeSharedObject<NDArrayView>(initValue, shape, device), true, {}, name, Internal::GenerateUid(VariableKind::Parameter))
+     : Parameter(shape, AsDataType<ElemType>(), ConstantInitializer(initValue), device, name)
  {}
  ///
  /// Construct a constant of specified shape whose contents are initialized with the specified 'initValue'
  ///
  Parameter(const NDShape& shape, DataType dataType, double initValue, const DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice(), const std::wstring& name = L"")
-     : Variable(shape, VariableKind::Parameter, dataType, MakeSharedObject<NDArrayView>(initValue, dataType, shape, device), true, {}, name, Internal::GenerateUid(VariableKind::Parameter))
+     : Parameter(shape, dataType, ConstantInitializer(initValue), device, name)
  {}
  ///
@@ -1962,7 +1975,7 @@ private:
  // However we have a couple of derivatives of the type to extend the base interface and thus we ensure that the derived types do not have additional fields.
  // This check is weak in that the derives types may sneak in some additional fields if the base type had some padding at the end, without changing the object size
  // but it should be good enough for catching any accidental additon of fields.
- static_assert(sizeof(Parameter) == sizeof(Variable), "The Parameter type should not have any data fields beyond what it's base type 'Variable' has.");
+ static_assert(sizeof(Parameter) == sizeof(Variable), "The Parameter type should not have any data fields beyond what its base type 'Variable' has.");
  ///
  /// Denotes Constant inputs of a Function.
@@ -1993,14 +2006,14 @@ private:
  ///
  template<typename ElemType>
  Constant(const NDShape& shape, ElemType initValue, const DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice(), const std::wstring& name = L"")
-     : Variable(shape, VariableKind::Constant, AsDataType<ElemType>(), MakeSharedObject<NDArrayView>(initValue, shape, device), false, {}, name, Internal::GenerateUid(VariableKind::Constant))
+     : Constant(shape, AsDataType<ElemType>(), ConstantInitializer(initValue), device, name)
  {}
  ///
  /// Construct a constant of specified shape whose contents are initialized with the specified 'initValue'
  ///
  Constant(const NDShape& shape, DataType dataType, double initValue, const DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice(), const std::wstring& name = L"")
-     : Variable(shape, VariableKind::Constant, dataType, MakeSharedObject<NDArrayView>(initValue, dataType, shape, device), false, {}, name, Internal::GenerateUid(VariableKind::Constant))
+     : Constant(shape, dataType, ConstantInitializer(initValue), device, name)
  {}
  ///
@@ -2042,13 +2055,23 @@ private:
  Constant(const NDArrayViewPtr& value, const std::wstring& name, const std::wstring& uid)
      : Variable(value->Shape(), VariableKind::Constant, value->GetDataType(), value->DeepClone(true), false, {}, name, uid)
  {}
+ ///
+ /// Construct a constant of specified shape whose contents are initialized using the specified initializer
+ ///
+ Constant(const NDShape& shape, DataType dataType, const ParameterInitializer& initializer, const DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice(), const std::wstring& name = L"")
+     : Variable(shape, VariableKind::Constant, dataType, nullptr, false, {}, name, Internal::GenerateUid(VariableKind::Parameter))
+ {
+     m_dataFields->SetValueInitialization(initializer, device);
+ }
  };
  // Implementation note: The Variable type is a value type and not polymorphic in nature.
  // However we have a couple of derivatives of the type to extend the base interface and thus we ensure that the derived types do not have additional fields.
  // This check is weak in that the derives types may sneak in some additional fields if the base type had some padding at the end, without changing the object size
  // but it should be good enough for catching any accidental additon of fields.
- static_assert(sizeof(Constant) == sizeof(Variable), "The Constant type should not have any data fields beyond what it's base type 'Variable' has.");
+ static_assert(sizeof(Constant) == sizeof(Variable), "The Constant type should not have any data fields beyond what its base type 'Variable' has.");
  }
  namespace std {
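The scalar-value Parameter and Constant constructors above now delegate to the ParameterInitializer-based overloads (deferring the actual fill via SetValueInitialization) instead of eagerly building an NDArrayView. A minimal sketch of the two now-equivalent construction paths, not part of the commit; DataType::Float is assumed from the CNTK v2 C++ API:

// Illustrative only: both forms route through the initializer-based constructor.
#include "CNTKLibrary.h"

void BuildEquivalentParameters()
{
    using namespace CNTK;
    auto device = DeviceDescriptor::UseDefaultDevice();
    NDShape shape({ 3, 4 });

    // Scalar-init form; its constructor now forwards to the one below.
    Parameter p1(shape, 0.5f, device, L"p1");

    // Explicit initializer form that p1 delegates to.
    Parameter p2(shape, DataType::Float, ConstantInitializer(0.5), device, L"p2");
}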

View File

@@ -1209,7 +1209,7 @@ namespace CNTK
  InvalidArgument("Variable%S with unknown shape detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str());
  if (variable.Shape().HasInferredDimension())
-     InvalidArgument("Variable%S with InferredDimension for at least one axis in it's shape, detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str());
+     InvalidArgument("Variable%S with InferredDimension for at least one axis in its shape, detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str());
  if (variable.DynamicAxes() == Axis::UnknownDynamicAxes)
      InvalidArgument("Variable%S with unknown dynamic axes detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str());
@@ -1629,10 +1629,14 @@ namespace CNTK
  ComputationNetworkBuilder<ElementType> builder(*m_computationNetwork);
- // TODO: We current only support one backprop root
+ // TODO: We currently only support one backprop root
  if (backpropRoots.size() > 1)
      LogicError("More than one backprop roots is currently unsupported");
+ auto placeholders = Placeholders();
+ if (!placeholders.empty())
+     InvalidArgument("All placeholders of a Function must be bound before performing a Forward computation on the Function!");
  // Now recursively create the network in a top-down fashion
  auto rootFunction = RootFunction();
  auto rootFunctionOutputs = rootFunction->Outputs();
@@ -2478,9 +2482,6 @@ namespace CNTK
  FunctionPtr TransposeAxes(const Variable& operand, const Axis& axis1, const Axis& axis2, const std::wstring& name)
  {
-     if (!axis1.IsStaticAxis() || !axis2.IsStaticAxis())
-         LogicError("TransposeAxes currently does not support transposing dynamic axes");
      auto additionalProperties = Dictionary();
      additionalProperties[PrimitiveFunction::AttributeNameAxis1] = axis1;
      additionalProperties[PrimitiveFunction::AttributeNameAxis2] = axis2;
@@ -2674,9 +2675,9 @@ namespace CNTK
  InvalidArgument("ClassificationError: The topN argument must be > 0!");
  if (topN == 1)
- {
+ {
  if (axis == Axis(0))
-     return Minus(Constant::Scalar(prediction.GetDataType(), 1.0), TransposeTimes(labels, Hardmax(prediction)), name);
+     return Minus(Constant::Scalar(prediction.GetDataType(), 1.0), TransposeTimes(labels, Hardmax(prediction)), name);
  else
  {
      auto axMax = ReduceMax(prediction, axis);
@@ -2686,7 +2687,7 @@ namespace CNTK
      auto capErr = GreaterEqual(axErr, Constant::Scalar(prediction.GetDataType(), 1.0));
      return ReduceMean(capErr, Axis::AllStaticAxes(), name);
  }
- }
+ }
  else
  {
      if (axis != Axis(0))
@@ -2872,7 +2873,7 @@ namespace CNTK
  {
  void VerifyIsSequence(const Variable& operand)
  {
-     // The operand must have at least one dynamic axis and it's first dynamic axis must be ordered
+     // The operand must have at least one dynamic axis and its first dynamic axis must be ordered
      if (operand.DynamicAxes().empty() || !operand.DynamicAxes()[0].IsOrdered())
          InvalidArgument("A sequence function can only be applied on operands with at least one dynamic axis and whose first dynamic axis is ordered");
  }

View File

@@ -330,7 +330,7 @@ namespace CNTK
  {
  bool anyParameterOperandDimsInferred = false;
  auto updateOperandShapeFunc = [](Variable& operand, const NDShape& newOperandShape) {
-     if (operand.IsParameter() && (operand.Shape() != newOperandShape))
+     if ((operand.IsParameter() || operand.IsConstant()) && (operand.Shape() != newOperandShape))
      {
          operand.m_dataFields->m_shape = newOperandShape;
          return true;
@@ -536,6 +536,14 @@ namespace CNTK
  // infer reduction dimensions if not given
  // If kernel has a lower rank than the input then the remaining dimensions are to be reduced over.
  size_t filterRank = kernelShape.Rank();
+ // If the trailing axis dimensionality of the kernel shape is NDShape::InferredDimension, we reduce over it by
+ // picking the corresponding operand shape dimensionality
+ // This is done by shrinking the filter rank and let the dimensions be inferred from the operand's shape
+ // TODO: Should we do this for all of the axes in kernelShape that have a dimensionailty of NDShape::InferredDimension?
+ if (kernelShape[filterRank - 1] == NDShape::InferredDimension)
+     filterRank--;
  size_t inputRank = operandShape.Rank();
  NDShape fromShape;
  if (op == PrimitiveOpType::Convolution)
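This hunk is the convolution-inference fix named in the commit title: when the kernel shape's trailing axis is NDShape::InferredDimension, the filter rank is shrunk by one so that the trailing dimension is treated as a reduction dimension and filled in from the operand's shape. A self-contained paraphrase of the rule, not CNTK code, ignoring strides, padding, and map counts:

// Illustrative paraphrase; a stand-in sentinel replaces NDShape::InferredDimension.
#include <cstddef>
#include <iostream>
#include <vector>

static const size_t InferredDimension = static_cast<size_t>(-1);

std::vector<size_t> InferKernelShape(std::vector<size_t> kernelShape,
                                     const std::vector<size_t>& operandShape)
{
    size_t filterRank = kernelShape.size();
    // The fix: a trailing inferred dimension is excluded from the filter rank...
    if (filterRank > 0 && kernelShape[filterRank - 1] == InferredDimension)
        filterRank--;
    // ...so that it is picked up from the operand, like any other reduction dimension.
    for (size_t i = filterRank; i < kernelShape.size(); ++i)
        kernelShape[i] = operandShape[i];
    return kernelShape;
}

int main()
{
    // A 3x3 kernel over a [28 x 28 x 64] operand with the channel count left inferred:
    for (size_t dim : InferKernelShape({ 3, 3, InferredDimension }, { 28, 28, 64 }))
        std::cout << dim << ' '; // prints: 3 3 64
}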
@@ -815,7 +823,7 @@ namespace CNTK
  Microsoft::MSR::CNTK::ComputationNetworkPtr m_computationNetwork;
  // The backpropRoots sepecified in the most recent 'Forward' call on 'this' Function.
- // This indicates for which of it's roots has 'this' Function retained required intermediate
+ // This indicates for which of its roots has 'this' Function retained required intermediate
  // states from the previos Forward call to be able to backpropagate gradients backwards from in
  // the next 'Backward' call.
  std::unordered_set<Variable> m_currentBackpropRoots;

View File

@@ -178,7 +178,7 @@ namespace CNTK
  stream >> isOrderedDynamicAxis;
  Axis* axisPtr = nullptr;
- if (Axis(staticAxisIdx).IsStaticAxis())
+ if (!Axis(staticAxisIdx).IsDynamicAxis())
      axisPtr = new Axis(staticAxisIdx);
  else
      axisPtr = new Axis(axisName, isOrderedDynamicAxis);

View File

@@ -155,7 +155,7 @@ namespace CNTK
  {
  // Ensure none of the shape dimensions are unknown
  if (viewShape.HasInferredDimension())
-     InvalidArgument("Cannot create an NDArrayView using a view shape that has unknown dimensions for any of it's axes!");
+     InvalidArgument("Cannot create an NDArrayView using a view shape that has unknown dimensions for any of its axes!");
  size_t matrixRowSize = (viewShape.Rank() > 0) ? viewShape[0] : 1;
  size_t matrixColSize = (viewShape.Rank() > 0) ? viewShape.SubShape(1).TotalSize() : 1;
@@ -439,8 +439,8 @@ namespace CNTK
  inline std::vector<Axis> GetDerivedDynamicAxes(const Axis& sourceAxis, size_t multiplicativeFactor, int additiveFactor)
  {
-     if (sourceAxis.IsStaticAxis())
-         LogicError("Static axes cannot be derived from to create new dynamic axes!");
+     if (!sourceAxis.IsDynamicAxis())
+         LogicError("Only dynamic axes can be derived from to create new dynamic axes!");
      if ((multiplicativeFactor == 0) && (additiveFactor == 0))
          LogicError("Zero size dynamic axes are not allowed!");

View File

@@ -746,7 +746,9 @@ bool BestGpu::LockDevice(int deviceId, bool trial)
  std::unique_ptr<CrossProcessMutex> mutex(new CrossProcessMutex(buffer));
  if (!mutex->Acquire(/*wait=*/false)) // GPU not available
  {
-     fprintf(stderr, "LockDevice: Failed to lock GPU %d for exclusive use.\n", deviceId);
+     if (GetMathLibTraceLevel() > 0)
+         fprintf(stderr, "LockDevice: Failed to lock GPU %d for exclusive use.\n", deviceId);
      return false;
  }
  else

View File

@@ -41,7 +41,7 @@ typedef unsigned char byte;
  namespace Microsoft { namespace MSR { namespace CNTK {
  MATH_API void SetMathLibTraceLevel(int traceLevel);
- int GetMathLibTraceLevel();
+ MATH_API int GetMathLibTraceLevel();
  class MATH_API TracingGPUMemoryAllocator
  {
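GetMathLibTraceLevel() gains the MATH_API export so that callers outside the Math library, such as the BestGpu locking path above, can gate diagnostic output on the library's trace level. A minimal usage sketch, not part of the commit; the header name below is a guess at where these declarations live:

// Illustrative only; the include path is hypothetical.
#include "CommonMatrix.h"
#include <cstdio>

void QuietGpuDiagnostics()
{
    using namespace Microsoft::MSR::CNTK;
    SetMathLibTraceLevel(0); // suppress informational messages, e.g. the LockDevice notice
    if (GetMathLibTraceLevel() > 0)
        fprintf(stderr, "math library tracing is enabled\n");
}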

View File

@@ -1514,11 +1514,11 @@ void GPUMatrix<ElemType>::RequireSize(const size_t numRows, const size_t numCols
  template <class ElemType>
  void GPUMatrix<ElemType>::Resize(const size_t numRows, const size_t numCols, bool growOnly)
  {
-     VerifyResizable(__func__);
      if (GetNumRows() == numRows && GetNumCols() == numCols)
          return;
+     VerifyResizable(__func__);
      size_t numElements = numRows * numCols;
      if (numElements > GetSizeAllocated() || // grow allocation
          (!growOnly && numElements != GetSizeAllocated())) // shrink allocation if not growOnly
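Moving the VerifyResizable check after the early return makes a same-size Resize a no-op even for matrices that cannot be resized (for example, ones wrapping external buffers), instead of failing up front. A toy paraphrase of the new ordering, not CNTK code:

// Illustrative only: same-size resizes succeed before the resizability check.
#include <cstddef>
#include <stdexcept>

struct ToyMatrix
{
    size_t rows = 0, cols = 0;
    bool resizable = true;

    void Resize(size_t numRows, size_t numCols)
    {
        if (rows == numRows && cols == numCols)
            return;                  // no-op: allowed even when !resizable
        if (!resizable)
            throw std::logic_error("Resize: matrix is not resizable");
        rows = numRows;
        cols = numCols;              // (re)allocate storage here in real code
    }
};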

View File

@@ -101,7 +101,7 @@ void TestReduceSum(size_t sampleRank, const DeviceDescriptor& device)
  {
  auto testReduceSum = [&sequences, &sequenceLengths, inputShape, sequencesValue, device](const Axis& axis)
  {
-     if (axis.IsStaticAxis())
+     if (!axis.IsDynamicAxis())
          RuntimeError("Called the dynamic axis ReduceSum test with a static axis");
      size_t maxActualSequenceLength = sequencesValue->Shape()[inputShape.Rank()];
@@ -219,7 +219,7 @@ void TestSlice(size_t sampleRank, const DeviceDescriptor& device)
  {
  auto testDynamicAxisSlice = [&sequences, &sequenceLengths, inputShape, sequencesValue, device](const Axis& axis, int beginOffset, int endOffset)
  {
-     if (axis.IsStaticAxis())
+     if (!axis.IsDynamicAxis())
          RuntimeError("Called the dynamic axis slice test with a static axis");
      size_t maxActualSequenceLength = sequencesValue->Shape()[inputShape.Rank()];

View File

@@ -287,7 +287,7 @@ def pooling(operand, pooling_type, pooling_window_shape, strides=(1,), auto_padd
  @typemap
  def batch_normalization(operand, scale, bias, running_mean, running_inv_std, spatial,
-                         normalization_time_constant=0, blend_time_constant=0,
+                         normalization_time_constant=5000, blend_time_constant=0,
                          epsilon=0.00001, use_cudnn_engine=False, name=''):
      '''
      Normalizes layer outputs for every minibatch for each output (feature) independently
@@ -305,9 +305,8 @@ def batch_normalization(operand, scale, bias, running_mean, running_inv_std, spa
      running_inv_std: running variance. Represented as ``running_mean``
      spatial(`bool`): flag that indicates whether to compute mean/var for each feature in a minibatch
          independently or, in case of convolutional layers, per future map
-     normalization_time_constant(`float`, default 0): time constant for computing running average of
-         mean and variance as a low-pass filtered version of the batch statistics. Note: the default is not
-         typically what you want
+     normalization_time_constant(`float`, default 5000): time constant for computing running average of
+         mean and variance as a low-pass filtered version of the batch statistics.
      blend_time_constant(`float`, default 0): constant for smoothing batch estimates with the running
          statistics
      epsilon: conditioner constant added to the variance when computing the inverse standard deviation
@@ -1817,7 +1816,7 @@ def input_variable(shape, data_type=np.float32, needs_gradient=True, is_sparse=F
  @typemap
- def placeholder_variable(shape, dynamic_axes=Axis.default_input_variable_dynamic_axes, name=''):
+ def placeholder_variable(shape=None, dynamic_axes=None, name=''):
      '''
      It creates a variable place holder for recurrence networks, when the network's dynamic axes
      are unfolded, the place holder will get assigned a variable along the correspondent dynamic axis.
@@ -1829,8 +1828,16 @@ def placeholder_variable(shape, dynamic_axes=Axis.default_input_variable_dynamic
      Returns:
          :class:`cntk.ops.functions.Function`
      '''
-     from cntk.cntk_py import placeholder_variable
-     shape = sanitize_shape(shape)
+     from cntk.cntk_py import placeholder_variable, NDShape, Axis
+     if shape is None:
+         shape = NDShape.unknown.dimensions()
+     else:
+         shape = sanitize_shape(shape)
+     if dynamic_axes is None:
+         dynamic_axes = Axis.unknown_dynamic_axes
      dynamic_axes = sanitize_dynamic_axes(dynamic_axes)
      return placeholder_variable(shape, name, dynamic_axes)
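With the new defaults, a placeholder can be created with both its shape and its dynamic axes unknown and resolved later when it is bound. A short usage sketch, illustrative only; plus and replace_placeholders are assumed from the cntk.ops / Function API and do not appear in this diff:

# Illustrative only, not part of the commit.
import cntk.ops as C

p = C.placeholder_variable()    # shape and dynamic axes now default to "unknown"
x = C.input_variable(shape=(3,))
f = C.plus(x, p)
f.replace_placeholders({p: x})  # p's shape and axes are resolved at binding time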