This commit is contained in:
Nikos Karampatziakis 2017-12-06 18:19:02 -08:00
Parent 3dc66d0304
Commit 5c97bd02ab
13 changed files with 310 additions and 14 deletions

View file

@@ -3825,6 +3825,18 @@ namespace CNTK
///
CNTK_API FunctionPtr Hardmax(const Variable& operand, const std::wstring& name = L"");
///
/// Create an instance of the CNTK built-in top k operation over the first static axis on a
/// specified tensor input operand
///
CNTK_API FunctionPtr TopK(const Variable& operand, size_t k, const std::wstring& name = L"");
///
/// Create an instance of the CNTK built-in top k operation over the specified axis on a
/// specified tensor input operand
///
CNTK_API FunctionPtr TopK(const Variable& operand, size_t k, const Axis& axis, const std::wstring& name = L"");
///
/// Create an instance of the CNTK built-in transpose dimensions operation on specified tensor input operand
///
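These overloads surface in Python as cntk.ops.top_k (added later in this commit). A minimal usage sketch, assuming a build that includes this change:

    import numpy as np
    import cntk as C

    x = C.input_variable(6)
    y = C.top_k(x, 2)  # a Function with two outputs: sorted top-k values, then indices

    out = y.eval({x: np.array([[3., 1., 4., 1., 5., 9.]], dtype=np.float32)})
    values = out[y.outputs[0]]   # [[ 9.,  5.]]
    indices = out[y.outputs[1]]  # [[ 5.,  4.]]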

View file

@@ -683,6 +683,12 @@ namespace CNTK
case PrimitiveOpType::Hardmax:
computationNodePtr = New<HardmaxNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::TopK:
{
auto k = functionConfig[PrimitiveFunction::AttributeNameNumItems].Value<size_t>();
computationNodePtr = New<TopKNode<ElementType>>(network->GetDeviceId(), internalNodeName, k);
break;
}
case PrimitiveOpType::StableSigmoid:
computationNodePtr = New<StableSigmoidNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;

View file

@@ -1287,6 +1287,40 @@ namespace CNTK
return UnaryOp(PrimitiveOpType::Hardmax, operand, Dictionary(), name);
}
FunctionPtr TopK(const Variable& operand, size_t k, const std::wstring& name)
{
auto additionalProperties = Dictionary();
additionalProperties[PrimitiveFunction::AttributeNameAxis] = Axis(0);
additionalProperties[PrimitiveFunction::AttributeNameNumItems] = k;
return UnaryOp(PrimitiveOpType::TopK, operand, std::move(additionalProperties), name);
}
FunctionPtr TopK(const Variable& operand, size_t k, const Axis& axis, const std::wstring& name)
{
if (!axis.IsStaticAxis())
LogicError("TopK operation only supports a single static axis.");
if (axis.StaticAxisIndex() == 0)
return TopK(operand, k, name);
else
{
auto additionalProperties = Dictionary();
additionalProperties[PrimitiveFunction::AttributeNameAxis] = axis;
additionalProperties[PrimitiveFunction::AttributeNameNumItems] = k;
auto operandPlaceholder = PlaceholderVariable();
auto firstAxis = Axis(0);
auto swapped = TransposeAxes(operandPlaceholder, firstAxis, axis);
auto topkSwapped = TopK(swapped, k, name);
auto outputs = topkSwapped->Outputs();
auto topkValues = TransposeAxes(outputs[0], firstAxis, axis);
auto topkIndices = TransposeAxes(outputs[1], firstAxis, axis);
auto result = Combine({ topkValues, topkIndices });
return AsBlock(std::move(result), { { operandPlaceholder, operand } }, std::move(additionalProperties), L"TopK", name);
}
}
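The arbitrary-axis overload reduces to the axis-0 primitive by conjugating with TransposeAxes: swap the requested axis to the front, take the top k there, and swap both outputs back. A numpy sketch of the same composition (the helper names are illustrative, not CNTK API):

    import numpy as np

    def topk_axis0(x, k):
        # Sorted top-k along the first axis, mirroring the axis-0 TopK primitive.
        idx = np.argsort(-x, axis=0)[:k]
        return np.take_along_axis(x, idx, axis=0), idx

    def topk(x, k, axis):
        # TransposeAxes / TopK / TransposeAxes, as in the block above.
        xs = np.swapaxes(x, 0, axis)
        vals, idx = topk_axis0(xs, k)
        return np.swapaxes(vals, 0, axis), np.swapaxes(idx, 0, axis)

    x = np.random.randn(4, 5)
    vals, _ = topk(x, 3, axis=1)
    assert np.allclose(vals, -np.sort(-x, axis=1)[:, :3])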
FunctionPtr TransposeAxes(const Variable& operand, const Axis& axis1, const Axis& axis2, const std::wstring& name)
{
auto additionalProperties = Dictionary();

View file

@@ -1041,6 +1041,17 @@ namespace CNTK
}
break;
}
case PrimitiveOpType::TopK:
{
assert(m_inputs.size() == 1);
auto k = m_attributes[PrimitiveFunction::AttributeNameNumItems].Value<size_t>();
outputShape = m_inputs[0].Shape();
if (outputShape.Rank() > 0)
outputShape[0] = k;
else if (k != 1)
RuntimeError("Function '%S': cannot get k>1 items from a scalar.", AsString().c_str());
break;
}
default:
LogicError("Specified Primitive Function op %S is not supported", PrimitiveOpTypeName(m_op).c_str());
break;
@@ -1060,6 +1071,11 @@
outputs.push_back(maskOutput);
}
}
else if (m_op == PrimitiveOpType::TopK)
{
auto indexOutput = OutputVariable(outputShape, outputDataType, outputDynamicAxes, /*needsGradient =*/ false, Name().empty() ? L"" : Name() + L"_TopKIndexMask");
outputs.push_back(indexOutput);
}
}
}
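The shape rule: the output keeps the input shape except that the leading static dimension becomes k, and a scalar input only admits k == 1; the indices come out as a second output of the same shape with needsGradient = false. A small sketch of the rule, for illustration:

    def topk_output_shape(input_shape, k):
        # Leading static dimension becomes k; both outputs share this shape.
        if len(input_shape) == 0:
            if k != 1:
                raise RuntimeError("cannot get k>1 items from a scalar")
            return input_shape
        return (k,) + tuple(input_shape[1:])

    assert topk_output_shape((10, 20, 30), 3) == (3, 20, 30)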

View file

@@ -112,6 +112,7 @@ namespace CNTK
{PrimitiveOpType::ToBatch, L"ToBatchAxis"},
{PrimitiveOpType::Pad, L"Pad"},
{PrimitiveOpType::Crop, L"Crop"},
{PrimitiveOpType::TopK, L"TopK"},
};
inline const std::wstring& PrimitiveOpTypeName(PrimitiveOpType opType)
@@ -291,6 +292,7 @@ namespace CNTK
static const std::wstring AttributeNameBias;
static const std::wstring AttributeNameDepthRadius;
static const std::wstring AttributeNameCustomAttributes;
static const std::wstring AttributeNameNumItems;
protected:
PrimitiveFunction(PrimitiveOpType op, const std::vector<Variable>& inputs, Dictionary&& functionConfig, const std::wstring& functionName, const std::wstring& uid)
@@ -804,7 +806,8 @@ namespace CNTK
// Version 16: Add to_batch/unpack_batch.
// Version 17: Add Pad.
// Version 18: Add Crop node.
-static const size_t s_serializationVersion = 18;
+// Version 19: Add TopK.
+static const size_t s_serializationVersion = 19;
};
std::vector<DictionaryValue> GetInputUids(const Function& f);

View file

@@ -99,4 +99,5 @@ namespace CNTK
/*static*/ const std::wstring PrimitiveFunction::AttributeNameBias = L"bias";
/*static*/ const std::wstring PrimitiveFunction::AttributeNameDepthRadius = L"depthRadius";
/*static*/ const std::wstring PrimitiveFunction::AttributeNameCustomAttributes = L"customAttributes";
/*static*/ const std::wstring PrimitiveFunction::AttributeNameNumItems = L"numItems";
}

View file

@@ -96,6 +96,7 @@ namespace CNTK
Crop = 84,
Atanh = 85,
Asinh = 86,
TopK = 87,
// New op types should only be appended to the end of this list
UnknownOP
// and UnknownOP should always be last.

View file

@@ -17,6 +17,7 @@
#include <list>
#include <memory>
#include <algorithm>
#include <numeric>
#include <assert.h>
namespace Microsoft { namespace MSR { namespace CNTK {
@@ -446,6 +447,114 @@ public:
template class HardmaxNode<float>;
template class HardmaxNode<double>;
template <class ElemType>
class TopKNode : public ComputationNode<ElemType>, public MultiOutputNode<ElemType>, public NumInputs<1>
{
typedef ComputationNode<ElemType> Base; UsingComputationNodeMembersBoilerplate;
static const std::wstring TypeName() { return L"TopK"; }
public:
TopKNode(DEVICEID_TYPE deviceId, const wstring& name) : Base(deviceId, name), MultiOutputNode<ElemType>(2) {}
TopKNode(DEVICEID_TYPE deviceId, const wstring& name, size_t k)
: Base(deviceId, name), MultiOutputNode<ElemType>(2), m_k(k) {}
virtual void RequestMatricesBeforeForwardProp(MatrixPool& matrixPool) override
{
Base::RequestMatricesBeforeForwardProp(matrixPool);
RequestMatrixFromPool(m_sortedIndices, matrixPool);
}
virtual void ReleaseMatricesAfterBackprop(MatrixPool& matrixPool) override
{
Base::ReleaseMatricesAfterBackprop(matrixPool);
ReleaseMatrixToPool(m_sortedIndices, matrixPool);
}
virtual void /*ComputationNode::*/ ForwardProp(const FrameRange& fr) override
{
#ifdef _MSC_VER
auto& outputValuePtrRef = this->ValuePtrRef();
auto& inputValuePtrRef = Input(0)->ValuePtrRef();
#else
auto& outputValuePtrRef = this->template ValuePtrRef();
auto& inputValuePtrRef = Input(0)->template ValuePtrRef();
#endif
auto dim = Input(0)->GetSampleLayout().GetDimPadded(0);
if (m_k > dim)
LogicError("TopK: number of requested elements k (=%zd) exceeds total number of elements (=%zd) on this axis", m_k, dim);
auto&& topkOutput = outputValuePtrRef->Reshaped(m_k, outputValuePtrRef->GetNumElements() / m_k);
auto&& topkInput = inputValuePtrRef->Reshaped(dim, inputValuePtrRef->GetNumElements() / dim);
topkInput.VectorMax(*m_sortedIndices, topkOutput, true, m_k);
this->m_outputsValue[1]->SetValue(m_sortedIndices->Reshaped(outputValuePtrRef->GetNumRows(), outputValuePtrRef->GetNumCols()));
}
virtual void /*ComputationNode::*/ BackpropTo(const size_t inputIndex, const FrameRange& fr) override
{
// Backpropagation works the same way as for other nodes that take top element(s) such as max pooling.
// The values that are not selected get a gradient of zero, otherwise the gradient is copied to the
// positions that were responsible for the top values. This is a scatter operation.
#ifdef _MSC_VER
auto&& inputGradient = Input(0)->GradientPtrRef();
auto&& outputGradient = GradientPtrRef();
#else
auto&& inputGradient = Input(0)->template GradientPtrRef();
auto&& outputGradient = this->template GradientPtrRef();
#endif
auto&& reshapedInputGradient = inputGradient->Reshaped(1, inputGradient->GetNumElements());
auto&& reshapedOutputGradient = outputGradient->Reshaped(1, outputGradient->GetNumElements());
// The indices take values between 0 and the dimension of the axis over which we compute the top k
// Since the matrix class lacks a scatter that can handle indices arising from gather operations
// over a particular axis of a multidimensional tensor, we patch the indices here so that they look
// as if they were generated from a gather-like operation over a 1-dimensional tensor.
auto numCols = m_sortedIndices->GetNumCols();
if (numCols != 1)
{
CreateMatrixIfNull(m_steps);
auto dim = Input(0)->GetSampleLayout().GetDimPadded(0);
std::vector<ElemType> tmp(numCols);
std::generate(tmp.begin(), tmp.end(), [i = ElemType(0), dim]() mutable { auto ret = i; i += dim; return ret; });
m_steps->SetValue(1, numCols, this->m_deviceId, tmp.data());
m_sortedIndices->ScaleAndAdd(ElemType(1), *m_steps, *m_sortedIndices);
}
reshapedInputGradient.DoScatterColumnsOf(ElemType(1), m_sortedIndices->Reshaped(1, m_sortedIndices->GetNumElements()), reshapedOutputGradient, ElemType(1));
}
virtual bool OutputUsedInComputingInputNodesGradients() const override { return false; }
virtual bool InputUsedInComputingInputNodesGradients(size_t childIndex) const override { return false; }
virtual void Validate(bool isFinalValidationPass) override
{
assert(m_inputs.size() == 1);
ComputationNodeBase::Validate(isFinalValidationPass);
InferMBLayoutFromInputsForStandardCase(isFinalValidationPass);
auto&& inputShape = Input(0)->GetSampleLayout();
SmallVector<size_t> outDims = inputShape.GetDims();
if (outDims.size() > 0)
outDims[0] = m_k;
auto outShape = TensorShape(outDims);
SetDims(outShape, Input(0)->HasMBLayout());
this->m_outputsMBLayout[1] = Input(0)->GetMBLayout();
this->m_outputsShape[1] = outShape;
}
private:
shared_ptr<Matrix<ElemType>> m_sortedIndices;
shared_ptr<Matrix<ElemType>> m_steps;
size_t m_k = 1; // default-initialized: the (deviceId, name) constructor does not set k
};
template class TopKNode<float>;
template class TopKNode<double>;
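Both passes of TopKNode view the tensor as a dim x cols matrix with the top-k axis as the rows. The forward pass is a per-column sorted top-k; the backward pass scatters the incoming gradient back to the winning positions, first offsetting each column's indices by column * dim (the m_steps buffer) so that a single flat DoScatterColumnsOf call suffices. A numpy sketch of both passes (shapes and names are illustrative, not the Matrix API):

    import numpy as np

    def topk_forward(x, k):
        # x: (dim, cols). Sorted top-k values and their row indices per column.
        idx = np.argsort(-x, axis=0)[:k]               # what VectorMax computes
        return np.take_along_axis(x, idx, axis=0), idx

    def topk_backward(out_grad, idx, dim):
        # Offset per-column indices by column * dim to get flat column-major
        # positions (the m_steps trick), then scatter the gradient there.
        k, cols = idx.shape
        flat = idx + np.arange(cols) * dim
        in_grad = np.zeros(dim * cols)
        np.add.at(in_grad, flat.ravel(order='F'), out_grad.ravel(order='F'))
        return in_grad.reshape((dim, cols), order='F')

    x = np.random.randn(5, 3)
    vals, idx = topk_forward(x, 2)
    grad = topk_backward(np.ones_like(vals), idx, dim=5)  # ones at the top-2 spots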
// -----------------------------------------------------------------------
// If (flag, ifValue, elseValue)
// -----------------------------------------------------------------------

View file

@@ -22,6 +22,7 @@
#include <thread>
#include <iostream>
#include <algorithm>
#include <numeric>
#pragma warning(push)
#pragma warning(disable:4244) // 'conversion' conversion from 'type1' to 'type2', possible loss of data
#include <boost/random/normal_distribution.hpp>
@@ -3826,23 +3827,19 @@ void CPUMatrix<ElemType>::VectorMax(CPUMatrix<ElemType>& maxIndexes, CPUMatrix<ElemType>& maxValues, const bool isColWise, int topK) const
else
{
    std::vector<int> indices(m);
-   int i = 0;
-   std::generate(indices.begin(), indices.end(), [&i]
-   {
-       return i++;
-   });
    const ElemType* curVal = Data();
    ElemType* curIdx = maxIndexes.Data();
    ElemType* curMax = maxValues.Data();
    for (int icol = 0; icol < n; icol++, curVal += m, curIdx += topK, curMax += topK)
    {
+       std::iota(indices.begin(), indices.end(), 0);
        // Partial sort, descending order.
-       std::nth_element(indices.begin(), indices.begin() + topK, indices.end(),
-                        [curVal](const int& a, const int& b)
-                        {
-                            return curVal[a] > curVal[b];
-                        });
+       std::partial_sort(indices.begin(), indices.begin() + topK, indices.end(),
+                         [curVal](const int& a, const int& b)
+                         {
+                             return curVal[a] > curVal[b];
+                         });
        // REVIEW alexeyk: the following produces warning (see SCL_SECURE_NO_WARNINGS) so use loop instead.
        // std::transform(indices.begin(), indices.begin() + topK, curIdx, [](const int& a) { return static_cast<ElemType>(a); });
        for (int i2 = 0; i2 < topK; i2++)
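The switch from std::nth_element to std::partial_sort is what makes TopK's "sorted order" promise hold: nth_element only partitions around the k-th element and leaves the first k in arbitrary order, while partial_sort also sorts them. (Re-running std::iota each column additionally makes tie-breaking independent of the previous column's leftover permutation.) The numpy analogue of the distinction, for illustration:

    import numpy as np

    x = np.array([3., 1., 4., 1., 5., 9., 2., 6.])

    # np.partition is the nth_element analogue: the top 3 values, unordered.
    unordered = -np.partition(-x, 3)[:3]
    # Sorting that prefix mirrors partial_sort: the top 3, descending.
    ordered = -np.sort(np.partition(-x, 3)[:3])
    print(ordered)  # [9. 6. 5.]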

View file

@@ -79,7 +79,7 @@
<PreprocessorDefinitions Condition="'$(CNTK_ENABLE_1BitSGD)'=='true' and '!$(IsUWP)'">CNTK_PARALLEL_TRAINING_SUPPORT;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<SDLCheck>true</SDLCheck>
<FavorSizeOrSpeed>Speed</FavorSizeOrSpeed>
<AdditionalOptions>/d2Zi+ %(AdditionalOptions)</AdditionalOptions>
<AdditionalOptions>/d2Zi+ /bigobj %(AdditionalOptions)</AdditionalOptions>
</ClCompile>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="$(CpuOnlyBuild)">
@@ -156,4 +156,4 @@
<Target Name="CheckDependencies">
<Error Condition="!$(HasMultiverso) And '$(CNTK_ENABLE_ASGD)'!='false' and '!$(IsUWP)'" Text="CNTK requires Multiverso to build, Please see https://docs.microsoft.com/en-us/cognitive-toolkit/Setup-CNTK-on-Windows#enlisting-in-the-cntk-github-repository for installation instructions." />
</Target>
</Project>
</Project>

View file

@@ -346,7 +346,8 @@ void CheckEnumValuesNotModified() {
static_cast<size_t>(PrimitiveOpType::Pad) == 83 &&
static_cast<size_t>(PrimitiveOpType::Crop) == 84 &&
static_cast<size_t>(PrimitiveOpType::Atanh) == 85 &&
static_cast<size_t>(PrimitiveOpType::Asinh) == 86,
static_cast<size_t>(PrimitiveOpType::Asinh) == 86 &&
static_cast<size_t>(PrimitiveOpType::TopK) == 87,
"PrimitiveOpType enum value was modified.");
}

View file

@@ -1724,6 +1724,39 @@ def hardmax(x, name=''):
return hardmax(x, name)
@typemap
def top_k(x, k, axis=-1, name=''):
'''
Computes the ``k`` largest values of the input tensor and the corresponding indices
along the specified axis (default the last axis). The returned
:class:`~cntk.ops.functions.Function` has two outputs. The first one
contains the top ``k`` values in sorted order, and the second one contains
the corresponding top ``k`` indices.
Example:
>>> x = C.input_variable(10)
>>> y = C.top_k(-x * C.log(x), 3)
>>> x0 = np.arange(10,dtype=np.float32)*0.1
>>> top = y.eval({x:x0})
>>> top_values = top[y.outputs[0]]
>>> top_indices = top[y.outputs[1]]
>>> top_indices
array([[ 4., 3., 5.]], dtype=float32)
Args:
x: numpy array or any :class:`~cntk.ops.functions.Function` that outputs a tensor
k (int): number of top items to return
axis: axis along which to perform the operation (default: -1)
name (str): the name of the Function instance in the network
Returns:
:class:`~cntk.ops.functions.Function`
'''
from cntk.cntk_py import top_k
x = sanitize_input(x)
axis = sanitize_axis(axis)
return top_k(x, k, axis, name)
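Note that the Python default axis=-1 and the C++ default of the first static axis agree: sanitize_axis maps Python's numpy-style axis numbering onto CNTK's reversed static-axis order, so axis=-1 becomes Axis(0). A quick sanity check (import path per CNTK 2.x; treat it as an assumption):

    from cntk.internal import sanitize_axis

    assert sanitize_axis(-1).static_axis_index() == 0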
@typemap
def exp(x, name=''):
'''

View file

@@ -696,3 +696,86 @@ def test_crop():
cropped = C.crop_automatic_with_ancestors(
node_output, node_referent, node_input, node_referent).eval(input_map)
assert np.array_equal(cropped, expected)
@pytest.mark.parametrize("axis", [-2, -1])
def test_topk(axis, device_id, precision):
def sliceit(x, axis):
    # Top 3 in descending order: take the last three entries of an
    # ascending (arg)sort along `axis`, reversed.
    if axis not in (-2, -1):
        raise ValueError("unknown axis %d" % axis)
    if axis == -1:
        return x[..., -1:-4:-1]
    elif axis == -2:
        return x[..., -1:-4:-1, :]
def check_topk_values_and_indices(top, y, x):
vals = top[y.outputs[0]]
idxs = top[y.outputs[1]]
for vi,xi in zip(vals, x):
assert np.allclose(vi, sliceit(np.sort(xi, axis=axis), axis))
for idxi,xi in zip(idxs, x):
assert np.allclose(idxi, sliceit(np.argsort(xi, axis=axis), axis))
dt = PRECISION_TO_TYPE[precision]
dev = cntk_device(device_id)
p = C.parameter((10, 20, 30), dtype=dt)
np.random.seed(90210)
p.value = p.value + np.random.randn(*p.shape)
y = C.top_k(p, 3, axis=axis)
top = y.eval({}) # for now run this on the device where the parameter is
assert np.allclose(top[y.outputs[0]], sliceit(np.sort(p.value, axis=axis), axis))
assert np.allclose(top[y.outputs[1]], sliceit(np.argsort(p.value, axis=axis), axis))
q = C.input_variable((5, 6), dtype=dt)
q0 = np.random.randn(2, 5, 6).astype(dt)
y = C.top_k(q, 3, axis=axis)
top = y.eval({q:q0}, device=dev)
check_topk_values_and_indices(top, y, q0)
q = C.sequence.input_variable((5, 6), dtype=dt)
q0 = [np.random.randn(4-i, 5, 6).astype(dt) for i in range(2)]
y = C.top_k(q, 3, axis=axis)
top = y.eval({q:q0}, device=dev)
check_topk_values_and_indices(top, y, q0)
def test_topk_backward(device_id, precision):
def check_grad_last_axis(input, root, indices, output):
d = input.shape[-1]
k = indices.shape[-1]
expected_output = np.zeros_like(input).reshape(-1,d)
ind = np.reshape(indices, (-1,k))
r = np.reshape(root,(-1,k))
assert ind.shape[0] == r.shape[0] == expected_output.shape[0]
for i in range(expected_output.shape[0]):
for j in range(k):
expected_output[i,int(ind[i,j])] = r[i,j]
expected_output = expected_output.reshape(input.shape)
assert np.allclose(output, expected_output)
dt = PRECISION_TO_TYPE[precision]
dev = cntk_device(device_id)
axis = -1
h = C.placeholder()
p = C.parameter((4, 5, 6), dtype=dt)
p.value = p.value + np.random.randn(*p.shape)
y = C.top_k(h, 3, axis=axis)
y.replace_placeholder(p)
dy, top = y.forward({}, y.outputs, set([y.outputs[0]]))
indices = top[y.outputs[1]]
root = np.ones_like(indices)
root = root + np.arange(np.prod(root.shape)).reshape(*root.shape)
cg = y.backward(dy, {y.outputs[0]:root}, set([p]))[p]
check_grad_last_axis(p.value, root, indices, cg)
q = C.sequence.input_variable((5,6), needs_gradient=True)
q0 = [np.random.randn(4-i,5,6).astype(dt) for i in range(2)]
y = C.top_k(q, 3, axis=axis)
dy, top = y.forward({q:q0}, y.outputs, set([y.outputs[0]]), device=dev)
indices = top[y.outputs[1]]
root = [np.ones_like(i) + 100 * k + np.arange(np.prod(i.shape)).reshape(*i.shape) for k,i in enumerate(indices)]
cg = y.backward(dy, {y.outputs[0]:root}, set([q]))[q]
for i in range(2):
check_grad_last_axis(q0[i], root[i], indices[i], cg[i])