CNTK v2 library: Fix a bug in convolution inference and other misc bug fixes
Parent: fce86cd958
Commit: e6036b2735
@@ -884,8 +884,8 @@ namespace CNTK
         CNTK_API static const std::wstring StaticAxisNamePrefix;

         CNTK_API static const int SentinelStaticAxisIndexValueForDynamicAxes;
-        static const int SentinelStaticAxisIndexValueForAllStaticAxes;
-        static const int SentinelStaticAxisIndexValueForUnknownAxes;
+        CNTK_API static const int SentinelStaticAxisIndexValueForAllStaticAxes;
+        CNTK_API static const int SentinelStaticAxisIndexValueForUnknownAxes;

         class UniqueDynamicAxesNames
         {
@@ -953,7 +953,20 @@ namespace CNTK
         ///
         /// Returns a boolean indicating if 'this' Axis corresponds to a static axis
         ///
-        bool IsStaticAxis() const { return m_staticAxisIdx != SentinelStaticAxisIndexValueForDynamicAxes; }
+        bool IsStaticAxis() const
+        {
+            return ((m_staticAxisIdx != SentinelStaticAxisIndexValueForDynamicAxes) &&
+                    (m_staticAxisIdx != SentinelStaticAxisIndexValueForAllStaticAxes) &&
+                    (m_staticAxisIdx != SentinelStaticAxisIndexValueForUnknownAxes));
+        }
+
+        ///
+        /// Returns a boolean indicating if 'this' Axis corresponds to a dynamic axis
+        ///
+        bool IsDynamicAxis() const
+        {
+            return (m_staticAxisIdx == SentinelStaticAxisIndexValueForDynamicAxes);
+        }

         ///
         /// Returns a boolean indicating if 'this' Axis is ordered; i.e. if there is an ordering between the dimensions along this axis.
@@ -1017,11 +1030,11 @@ namespace CNTK

    inline bool operator==(const Axis& first, const Axis& second)
    {
-        if (first.IsStaticAxis() != second.IsStaticAxis())
+        if (first.IsDynamicAxis() != second.IsDynamicAxis())
            return false;

-        if (first.IsStaticAxis())
-            return first.StaticAxisIndex() == second.StaticAxisIndex();
+        if (!first.IsDynamicAxis())
+            return first.StaticAxisIndex(/*checked =*/ false) == second.StaticAxisIndex(/*checked =*/ false);
        else
            return first.Name() == second.Name();
    }
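The net effect of the two Axis hunks above: the sentinel axes (Axis::AllStaticAxes() and the unknown-axes sentinel) no longer report themselves as static, and equality now keys off IsDynamicAxis() plus an unchecked static-axis index. A minimal sketch of the resulting behaviour (illustration only, using the sentinel constants declared earlier):

    #include "CNTKLibrary.h"
    #include <cassert>

    void AxisPredicateSketch()
    {
        using namespace CNTK;

        Axis staticAxis(0);                        // ordinary static axis
        Axis batchAxis = Axis::DefaultBatchAxis(); // dynamic axis
        Axis allStatic = Axis::AllStaticAxes();    // sentinel standing for "all static axes"

        assert(staticAxis.IsStaticAxis() && !staticAxis.IsDynamicAxis());
        assert(batchAxis.IsDynamicAxis() && !batchAxis.IsStaticAxis());

        // The sentinel is now neither static nor dynamic, which is why call sites below
        // switch from IsStaticAxis() to !IsDynamicAxis() when sentinel axes should be accepted.
        assert(!allStatic.IsStaticAxis() && !allStatic.IsDynamicAxis());

        // operator== compares unchecked static-axis indices, so sentinel axes still compare equal.
        assert(allStatic == Axis::AllStaticAxes());
    }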
@@ -1915,14 +1928,14 @@ private:
        ///
        template<typename ElemType>
        Parameter(const NDShape& shape, ElemType initValue, const DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice(), const std::wstring& name = L"")
-            : Variable(shape, VariableKind::Parameter, AsDataType<ElemType>(), MakeSharedObject<NDArrayView>(initValue, shape, device), true, {}, name, Internal::GenerateUid(VariableKind::Parameter))
+            : Parameter(shape, AsDataType<ElemType>(), ConstantInitializer(initValue), device, name)
        {}

        ///
        /// Construct a constant of specified shape whose contents are initialized with the specified 'initValue'
        ///
        Parameter(const NDShape& shape, DataType dataType, double initValue, const DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice(), const std::wstring& name = L"")
-            : Variable(shape, VariableKind::Parameter, dataType, MakeSharedObject<NDArrayView>(initValue, dataType, shape, device), true, {}, name, Internal::GenerateUid(VariableKind::Parameter))
+            : Parameter(shape, dataType, ConstantInitializer(initValue), device, name)
        {}

        ///
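From the caller's perspective nothing changes here: a Parameter (or Constant, below) built from a scalar initial value is used exactly as before; it now just routes through ConstantInitializer and the deferred value-initialization path instead of eagerly building an NDArrayView. Illustration only (device, shapes, and names are arbitrary):

    #include "CNTKLibrary.h"

    void ScalarInitializedParameterSketch()
    {
        using namespace CNTK;

        auto device = DeviceDescriptor::CPUDevice();
        Parameter bias({ 128 }, 0.0f, device, L"bias");           // template overload, ElemType = float
        Parameter scale({ 128 }, DataType::Double, 1.0, device);  // explicit DataType overload
    }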
@@ -1962,7 +1975,7 @@ private:
    // However we have a couple of derivatives of the type to extend the base interface and thus we ensure that the derived types do not have additional fields.
    // This check is weak in that the derives types may sneak in some additional fields if the base type had some padding at the end, without changing the object size
    // but it should be good enough for catching any accidental additon of fields.
-    static_assert(sizeof(Parameter) == sizeof(Variable), "The Parameter type should not have any data fields beyond what it's base type 'Variable' has.");
+    static_assert(sizeof(Parameter) == sizeof(Variable), "The Parameter type should not have any data fields beyond what its base type 'Variable' has.");

    ///
    /// Denotes Constant inputs of a Function.
@@ -1993,14 +2006,14 @@ private:
        ///
        template<typename ElemType>
        Constant(const NDShape& shape, ElemType initValue, const DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice(), const std::wstring& name = L"")
-            : Variable(shape, VariableKind::Constant, AsDataType<ElemType>(), MakeSharedObject<NDArrayView>(initValue, shape, device), false, {}, name, Internal::GenerateUid(VariableKind::Constant))
+            : Constant(shape, AsDataType<ElemType>(), ConstantInitializer(initValue), device, name)
        {}

        ///
        /// Construct a constant of specified shape whose contents are initialized with the specified 'initValue'
        ///
        Constant(const NDShape& shape, DataType dataType, double initValue, const DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice(), const std::wstring& name = L"")
-            : Variable(shape, VariableKind::Constant, dataType, MakeSharedObject<NDArrayView>(initValue, dataType, shape, device), false, {}, name, Internal::GenerateUid(VariableKind::Constant))
+            : Constant(shape, dataType, ConstantInitializer(initValue), device, name)
        {}

        ///
@@ -2042,13 +2055,23 @@ private:
        Constant(const NDArrayViewPtr& value, const std::wstring& name, const std::wstring& uid)
            : Variable(value->Shape(), VariableKind::Constant, value->GetDataType(), value->DeepClone(true), false, {}, name, uid)
        {}
+
+        ///
+        /// Construct a constant of specified shape whose contents are initialized using the specified initializer
+        ///
+        Constant(const NDShape& shape, DataType dataType, const ParameterInitializer& initializer, const DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice(), const std::wstring& name = L"")
+            : Variable(shape, VariableKind::Constant, dataType, nullptr, false, {}, name, Internal::GenerateUid(VariableKind::Parameter))
+        {
+            m_dataFields->SetValueInitialization(initializer, device);
+        }
+
    };

    // Implementation note: The Variable type is a value type and not polymorphic in nature.
    // However we have a couple of derivatives of the type to extend the base interface and thus we ensure that the derived types do not have additional fields.
    // This check is weak in that the derives types may sneak in some additional fields if the base type had some padding at the end, without changing the object size
    // but it should be good enough for catching any accidental additon of fields.
-    static_assert(sizeof(Constant) == sizeof(Variable), "The Constant type should not have any data fields beyond what it's base type 'Variable' has.");
+    static_assert(sizeof(Constant) == sizeof(Variable), "The Constant type should not have any data fields beyond what its base type 'Variable' has.");
}

namespace std {
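The new Constant constructor mirrors the Parameter path: rather than materializing an NDArrayView up front, it records the initializer and target device via SetValueInitialization, and the actual value is created once the data type and shape are fully known. A minimal sketch (ConstantInitializer is used because it appears above; any ParameterInitializer could be passed instead):

    #include "CNTKLibrary.h"

    void ConstantFromInitializerSketch()
    {
        using namespace CNTK;

        // The backing NDArrayView for this Constant is created lazily from the recorded initializer.
        Constant ones({ 300, 500 }, DataType::Float, ConstantInitializer(1.0),
                      DeviceDescriptor::CPUDevice(), L"allOnes");
    }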
@@ -1209,7 +1209,7 @@ namespace CNTK
            InvalidArgument("Variable%S with unknown shape detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str());

        if (variable.Shape().HasInferredDimension())
-            InvalidArgument("Variable%S with InferredDimension for at least one axis in it's shape, detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str());
+            InvalidArgument("Variable%S with InferredDimension for at least one axis in its shape, detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str());

        if (variable.DynamicAxes() == Axis::UnknownDynamicAxes)
            InvalidArgument("Variable%S with unknown dynamic axes detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str());
@@ -1629,10 +1629,14 @@ namespace CNTK

            ComputationNetworkBuilder<ElementType> builder(*m_computationNetwork);

-            // TODO: We current only support one backprop root
+            // TODO: We currently only support one backprop root
            if (backpropRoots.size() > 1)
                LogicError("More than one backprop roots is currently unsupported");

+            auto placeholders = Placeholders();
+            if (!placeholders.empty())
+                InvalidArgument("All placeholders of a Function must be bound before performing a Forward computation on the Function!");
+
            // Now recursively create the network in a top-down fashion
            auto rootFunction = RootFunction();
            auto rootFunctionOutputs = rootFunction->Outputs();
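The added check turns unbound placeholders into a hard error at Forward time rather than a later failure inside network construction. A rough sketch of what callers are expected to do (illustration only; PlaceholderVariable/ReplacePlaceholders are the usual binding APIs and are assumed to be available in this build):

    #include "CNTKLibrary.h"

    void BindPlaceholdersBeforeForwardSketch()
    {
        using namespace CNTK;

        auto placeholder = PlaceholderVariable(NDShape({ 1 }));
        auto model = Plus(placeholder, Constant::Scalar(DataType::Float, 1.0));

        // Bind the placeholder to a real input; calling Forward on 'model' while the
        // placeholder is still unbound now raises InvalidArgument instead of failing later.
        auto input = InputVariable({ 1 }, DataType::Float);
        model->ReplacePlaceholders({ { placeholder, input } });
    }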
@@ -2478,9 +2482,6 @@ namespace CNTK

    FunctionPtr TransposeAxes(const Variable& operand, const Axis& axis1, const Axis& axis2, const std::wstring& name)
    {
-        if (!axis1.IsStaticAxis() || !axis2.IsStaticAxis())
-            LogicError("TransposeAxes currently does not support transposing dynamic axes");
-
        auto additionalProperties = Dictionary();
        additionalProperties[PrimitiveFunction::AttributeNameAxis1] = axis1;
        additionalProperties[PrimitiveFunction::AttributeNameAxis2] = axis2;
@@ -2674,9 +2675,9 @@ namespace CNTK
        InvalidArgument("ClassificationError: The topN argument must be > 0!");

        if (topN == 1)
        {
            if (axis == Axis(0))
                return Minus(Constant::Scalar(prediction.GetDataType(), 1.0), TransposeTimes(labels, Hardmax(prediction)), name);
            else
            {
                auto axMax = ReduceMax(prediction, axis);
@@ -2686,7 +2687,7 @@ namespace CNTK
                auto capErr = GreaterEqual(axErr, Constant::Scalar(prediction.GetDataType(), 1.0));
                return ReduceMean(capErr, Axis::AllStaticAxes(), name);
            }
        }
        else
        {
            if (axis != Axis(0))
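In formula form, the top-1 branch above computes, per sample along the default Axis(0):

    error = 1 - labels . hardmax(prediction)

and for any other axis it reduces along that axis (axMax = ReduceMax(prediction, axis), giving a per-axis error indicator axErr), caps the indicator at one via GreaterEqual(axErr, 1), and averages it with ReduceMean(capErr, Axis::AllStaticAxes()).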
@@ -2872,7 +2873,7 @@ namespace CNTK
    {
        void VerifyIsSequence(const Variable& operand)
        {
-            // The operand must have at least one dynamic axis and it's first dynamic axis must be ordered
+            // The operand must have at least one dynamic axis and its first dynamic axis must be ordered
            if (operand.DynamicAxes().empty() || !operand.DynamicAxes()[0].IsOrdered())
                InvalidArgument("A sequence function can only be applied on operands with at least one dynamic axis and whose first dynamic axis is ordered");
        }
@@ -330,7 +330,7 @@ namespace CNTK
    {
        bool anyParameterOperandDimsInferred = false;
        auto updateOperandShapeFunc = [](Variable& operand, const NDShape& newOperandShape) {
-            if (operand.IsParameter() && (operand.Shape() != newOperandShape))
+            if ((operand.IsParameter() || operand.IsConstant()) && (operand.Shape() != newOperandShape))
            {
                operand.m_dataFields->m_shape = newOperandShape;
                return true;
@@ -536,6 +536,14 @@ namespace CNTK
                // infer reduction dimensions if not given
                // If kernel has a lower rank than the input then the remaining dimensions are to be reduced over.
                size_t filterRank = kernelShape.Rank();
+
+                // If the trailing axis dimensionality of the kernel shape is NDShape::InferredDimension, we reduce over it by
+                // picking the corresponding operand shape dimensionality
+                // This is done by shrinking the filter rank and let the dimensions be inferred from the operand's shape
+                // TODO: Should we do this for all of the axes in kernelShape that have a dimensionailty of NDShape::InferredDimension?
+                if (kernelShape[filterRank - 1] == NDShape::InferredDimension)
+                    filterRank--;
+
                size_t inputRank = operandShape.Rank();
                NDShape fromShape;
                if (op == PrimitiveOpType::Convolution)
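This hunk is the convolution-inference fix named in the commit title: when the kernel shape's trailing axis is NDShape::InferredDimension, that axis is treated as a reduction axis whose size will be filled in from the operand, and the effective filter rank shrinks accordingly. A self-contained sketch of just that rule (stand-in types, not the library code):

    #include <cstddef>
    #include <vector>

    // Stand-in for NDShape::InferredDimension (an all-ones sentinel value).
    constexpr size_t InferredDimension = static_cast<size_t>(-1);

    size_t EffectiveFilterRank(const std::vector<size_t>& kernelShape)
    {
        size_t filterRank = kernelShape.size();
        if ((filterRank > 0) && (kernelShape[filterRank - 1] == InferredDimension))
            filterRank--; // the trailing inferred axis becomes a reduction axis, filled in from the operand
        return filterRank;
    }

    // Example: a kernel shape of {3, 3, InferredDimension} over an operand of shape {W, H, C}
    // yields an effective filter rank of 2; the inferred axis is later resolved to C.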
@@ -815,7 +823,7 @@ namespace CNTK
        Microsoft::MSR::CNTK::ComputationNetworkPtr m_computationNetwork;

        // The backpropRoots sepecified in the most recent 'Forward' call on 'this' Function.
-        // This indicates for which of it's roots has 'this' Function retained required intermediate
+        // This indicates for which of its roots has 'this' Function retained required intermediate
        // states from the previos Forward call to be able to backpropagate gradients backwards from in
        // the next 'Backward' call.
        std::unordered_set<Variable> m_currentBackpropRoots;
@@ -178,7 +178,7 @@ namespace CNTK
            stream >> isOrderedDynamicAxis;

            Axis* axisPtr = nullptr;
-            if (Axis(staticAxisIdx).IsStaticAxis())
+            if (!Axis(staticAxisIdx).IsDynamicAxis())
                axisPtr = new Axis(staticAxisIdx);
            else
                axisPtr = new Axis(axisName, isOrderedDynamicAxis);
@@ -155,7 +155,7 @@ namespace CNTK
    {
        // Ensure none of the shape dimensions are unknown
        if (viewShape.HasInferredDimension())
-            InvalidArgument("Cannot create an NDArrayView using a view shape that has unknown dimensions for any of it's axes!");
+            InvalidArgument("Cannot create an NDArrayView using a view shape that has unknown dimensions for any of its axes!");

        size_t matrixRowSize = (viewShape.Rank() > 0) ? viewShape[0] : 1;
        size_t matrixColSize = (viewShape.Rank() > 0) ? viewShape.SubShape(1).TotalSize() : 1;
@@ -439,8 +439,8 @@ namespace CNTK

    inline std::vector<Axis> GetDerivedDynamicAxes(const Axis& sourceAxis, size_t multiplicativeFactor, int additiveFactor)
    {
-        if (sourceAxis.IsStaticAxis())
-            LogicError("Static axes cannot be derived from to create new dynamic axes!");
+        if (!sourceAxis.IsDynamicAxis())
+            LogicError("Only dynamic axes can be derived from to create new dynamic axes!");

        if ((multiplicativeFactor == 0) && (additiveFactor == 0))
            LogicError("Zero size dynamic axes are not allowed!");
@@ -746,7 +746,9 @@ bool BestGpu::LockDevice(int deviceId, bool trial)
    std::unique_ptr<CrossProcessMutex> mutex(new CrossProcessMutex(buffer));
    if (!mutex->Acquire(/*wait=*/false)) // GPU not available
    {
-        fprintf(stderr, "LockDevice: Failed to lock GPU %d for exclusive use.\n", deviceId);
+        if (GetMathLibTraceLevel() > 0)
+            fprintf(stderr, "LockDevice: Failed to lock GPU %d for exclusive use.\n", deviceId);
+
        return false;
    }
    else
@@ -41,7 +41,7 @@ typedef unsigned char byte;
namespace Microsoft { namespace MSR { namespace CNTK {

MATH_API void SetMathLibTraceLevel(int traceLevel);
-int GetMathLibTraceLevel();
+MATH_API int GetMathLibTraceLevel();

class MATH_API TracingGPUMemoryAllocator
{
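Exporting GetMathLibTraceLevel with MATH_API lets code outside the math DLL query the trace level, which is what the LockDevice change above relies on. A small fragment (assumes the declaring Math header is included):

    void SilenceMathLibTraceSketch()
    {
        // Lowering the trace level suppresses informational messages such as the GPU-locking notice above.
        Microsoft::MSR::CNTK::SetMathLibTraceLevel(0);
        int level = Microsoft::MSR::CNTK::GetMathLibTraceLevel();
        (void)level;
    }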
@@ -1514,11 +1514,11 @@ void GPUMatrix<ElemType>::RequireSize(const size_t numRows, const size_t numCols
template <class ElemType>
void GPUMatrix<ElemType>::Resize(const size_t numRows, const size_t numCols, bool growOnly)
{
-    VerifyResizable(__func__);
-
    if (GetNumRows() == numRows && GetNumCols() == numCols)
        return;

+    VerifyResizable(__func__);
+
    size_t numElements = numRows * numCols;
    if (numElements > GetSizeAllocated() || // grow allocation
        (!growOnly && numElements != GetSizeAllocated())) // shrink allocation if not growOnly
@@ -101,7 +101,7 @@ void TestReduceSum(size_t sampleRank, const DeviceDescriptor& device)
{
    auto testReduceSum = [&sequences, &sequenceLengths, inputShape, sequencesValue, device](const Axis& axis)
    {
-        if (axis.IsStaticAxis())
+        if (!axis.IsDynamicAxis())
            RuntimeError("Called the dynamic axis ReduceSum test with a static axis");

        size_t maxActualSequenceLength = sequencesValue->Shape()[inputShape.Rank()];
@@ -219,7 +219,7 @@ void TestSlice(size_t sampleRank, const DeviceDescriptor& device)
{
    auto testDynamicAxisSlice = [&sequences, &sequenceLengths, inputShape, sequencesValue, device](const Axis& axis, int beginOffset, int endOffset)
    {
-        if (axis.IsStaticAxis())
+        if (!axis.IsDynamicAxis())
            RuntimeError("Called the dynamic axis slice test with a static axis");

        size_t maxActualSequenceLength = sequencesValue->Shape()[inputShape.Rank()];
@@ -287,7 +287,7 @@ def pooling(operand, pooling_type, pooling_window_shape, strides=(1,), auto_padd

@typemap
def batch_normalization(operand, scale, bias, running_mean, running_inv_std, spatial,
-                        normalization_time_constant=0, blend_time_constant=0,
+                        normalization_time_constant=5000, blend_time_constant=0,
                        epsilon=0.00001, use_cudnn_engine=False, name=''):
    '''
    Normalizes layer outputs for every minibatch for each output (feature) independently
@@ -305,9 +305,8 @@ def batch_normalization(operand, scale, bias, running_mean, running_inv_std, spa
        running_inv_std: running variance. Represented as ``running_mean``
        spatial(`bool`): flag that indicates whether to compute mean/var for each feature in a minibatch
         independently or, in case of convolutional layers, per future map
-        normalization_time_constant(`float`, default 0): time constant for computing running average of
-         mean and variance as a low-pass filtered version of the batch statistics. Note: the default is not
-         typically what you want
+        normalization_time_constant(`float`, default 5000): time constant for computing running average of
+         mean and variance as a low-pass filtered version of the batch statistics.
        blend_time_constant(`float`, default 0): constant for smoothing batch estimates with the running
         statistics
        epsilon: conditioner constant added to the variance when computing the inverse standard deviation
@@ -1817,7 +1816,7 @@ def input_variable(shape, data_type=np.float32, needs_gradient=True, is_sparse=F


@typemap
-def placeholder_variable(shape, dynamic_axes=Axis.default_input_variable_dynamic_axes, name=''):
+def placeholder_variable(shape=None, dynamic_axes=None, name=''):
    '''
    It creates a variable place holder for recurrence networks, when the network's dynamic axes
    are unfolded, the place holder will get assigned a variable along the correspondent dynamic axis.
@@ -1829,8 +1828,16 @@ def placeholder_variable(shape, dynamic_axes=Axis.default_input_variable_dynamic
    Returns:
        :class:`cntk.ops.functions.Function`
    '''
-    from cntk.cntk_py import placeholder_variable
-    shape = sanitize_shape(shape)
+    from cntk.cntk_py import placeholder_variable, NDShape, Axis
+
+    if shape is None:
+        shape = NDShape.unknown.dimensions()
+    else:
+        shape = sanitize_shape(shape)
+
+    if dynamic_axes is None:
+        dynamic_axes = Axis.unknown_dynamic_axes
+
    dynamic_axes = sanitize_dynamic_axes(dynamic_axes)
    return placeholder_variable(shape, name, dynamic_axes)
