add top k operation. Closes #2468.
This commit is contained in:
Parent: 3dc66d0304
Commit: 5c97bd02ab
@@ -3825,6 +3825,18 @@ namespace CNTK
    ///
    CNTK_API FunctionPtr Hardmax(const Variable& operand, const std::wstring& name = L"");

    ///
    /// Create an instance of the CNTK built-in top k operation over the first static axis on a
    /// specified tensor input operand
    ///
    CNTK_API FunctionPtr TopK(const Variable& operand, size_t k, const std::wstring& name = L"");

    ///
    /// Create an instance of the CNTK built-in top k operation over the specified axis on a
    /// specified tensor input operand
    ///
    CNTK_API FunctionPtr TopK(const Variable& operand, size_t k, const Axis& axis, const std::wstring& name = L"");

    ///
    /// Create an instance of the CNTK built-in transpose dimensions operation on specified tensor input operand
    ///
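The first overload fixes the axis to the leading static axis; the second accepts any static axis. As a point of reference, the intended semantics match a descending sort followed by a length-k slice. A minimal numpy sketch of those semantics (illustrative only, not part of this commit):

    import numpy as np

    def topk_first_axis(x, k):
        # Indices of the k largest entries along axis 0, in descending value order.
        idx = np.argsort(-x, axis=0)[:k]
        # Gather the corresponding values; (values, indices) mirrors TopK's two outputs.
        vals = np.take_along_axis(x, idx, axis=0)
        return vals, idx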
@@ -683,6 +683,12 @@ namespace CNTK
        case PrimitiveOpType::Hardmax:
            computationNodePtr = New<HardmaxNode<ElementType>>(network->GetDeviceId(), internalNodeName);
            break;
        case PrimitiveOpType::TopK:
        {
            auto k = functionConfig[PrimitiveFunction::AttributeNameNumItems].Value<size_t>();
            computationNodePtr = New<TopKNode<ElementType>>(network->GetDeviceId(), internalNodeName, k);
            break;
        }
        case PrimitiveOpType::StableSigmoid:
            computationNodePtr = New<StableSigmoidNode<ElementType>>(network->GetDeviceId(), internalNodeName);
            break;
@@ -1287,6 +1287,40 @@ namespace CNTK
        return UnaryOp(PrimitiveOpType::Hardmax, operand, Dictionary(), name);
    }

    FunctionPtr TopK(const Variable& operand, size_t k, const std::wstring& name)
    {
        auto additionalProperties = Dictionary();
        additionalProperties[PrimitiveFunction::AttributeNameAxis] = Axis(0);
        additionalProperties[PrimitiveFunction::AttributeNameNumItems] = k;
        return UnaryOp(PrimitiveOpType::TopK, operand, std::move(additionalProperties), name);
    }

    FunctionPtr TopK(const Variable& operand, size_t k, const Axis& axis, const std::wstring& name)
    {
        if (!axis.IsStaticAxis())
            LogicError("TopK operation only supports a single static axis.");

        if (axis.StaticAxisIndex() == 0)
            return TopK(operand, k, name);
        else
        {
            auto additionalProperties = Dictionary();
            additionalProperties[PrimitiveFunction::AttributeNameAxis] = axis;
            additionalProperties[PrimitiveFunction::AttributeNameNumItems] = k;

            auto operandPlaceholder = PlaceholderVariable();
            auto firstAxis = Axis(0);
            auto swapped = TransposeAxes(operandPlaceholder, firstAxis, axis);
            auto topkSwapped = TopK(swapped, k, name);
            auto outputs = topkSwapped->Outputs();
            auto topkValues = TransposeAxes(outputs[0], firstAxis, axis);
            auto topkIndices = TransposeAxes(outputs[1], firstAxis, axis);
            auto result = Combine({ topkValues, topkIndices });
            return AsBlock(std::move(result), { { operandPlaceholder, operand } }, std::move(additionalProperties), L"TopK", name);
        }
    }

    FunctionPtr TransposeAxes(const Variable& operand, const Axis& axis1, const Axis& axis2, const std::wstring& name)
    {
        auto additionalProperties = Dictionary();
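For an arbitrary static axis, the composition above swaps the requested axis with axis 0, applies the first-axis TopK, swaps both outputs back, and wraps the whole thing as a single named block. A rough numpy analogue of that reduction (illustrative sketch, not CNTK API):

    import numpy as np

    def topk_any_axis(x, k, axis):
        # Swap `axis` to the front, reuse the axis-0 routine, then swap back.
        swapped = np.swapaxes(x, 0, axis)
        idx = np.argsort(-swapped, axis=0)[:k]
        vals = np.take_along_axis(swapped, idx, axis=0)
        return np.swapaxes(vals, 0, axis), np.swapaxes(idx, 0, axis)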
@@ -1041,6 +1041,17 @@ namespace CNTK
            }
            break;
        }
        case PrimitiveOpType::TopK:
        {
            assert(m_inputs.size() == 1);
            auto k = m_attributes[PrimitiveFunction::AttributeNameNumItems].Value<size_t>();
            outputShape = m_inputs[0].Shape();
            if (outputShape.Rank() > 0)
                outputShape[0] = k;
            else if (k != 1)
                RuntimeError("Function '%S': cannot get k>1 items from a scalar.", AsString().c_str());
            break;
        }
        default:
            LogicError("Specified Primitive Function op %S is not supported", PrimitiveOpTypeName(m_op).c_str());
            break;

@@ -1060,6 +1071,11 @@ namespace CNTK
                outputs.push_back(maskOutput);
            }
        }
        else if (m_op == PrimitiveOpType::TopK)
        {
            auto IndexOutput = OutputVariable(outputShape, outputDataType, outputDynamicAxes, /*needsGradient =*/ false, Name().empty() ? L"" : Name() + L"_TopKIndexMask");
            outputs.push_back(IndexOutput);
        }
    }
}
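Shape inference replaces the leading dimension with k and leaves the remaining dimensions untouched; a scalar operand is legal only for k == 1. The same rule spelled out as a small sketch (hypothetical helper topk_output_shape, for illustration):

    def topk_output_shape(input_shape, k):
        # Dim 0 becomes k; all other dims pass through unchanged.
        if len(input_shape) > 0:
            return (k,) + tuple(input_shape[1:])
        # Rank-0 (scalar) input: only the degenerate k == 1 is allowed.
        if k != 1:
            raise ValueError("cannot get k>1 items from a scalar")
        return ()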
@@ -112,6 +112,7 @@ namespace CNTK
        {PrimitiveOpType::ToBatch, L"ToBatchAxis"},
        {PrimitiveOpType::Pad, L"Pad"},
        {PrimitiveOpType::Crop, L"Crop"},
        {PrimitiveOpType::TopK, L"TopK"},
    };

    inline const std::wstring& PrimitiveOpTypeName(PrimitiveOpType opType)
@@ -291,6 +292,7 @@ namespace CNTK
        static const std::wstring AttributeNameBias;
        static const std::wstring AttributeNameDepthRadius;
        static const std::wstring AttributeNameCustomAttributes;
        static const std::wstring AttributeNameNumItems;

    protected:
        PrimitiveFunction(PrimitiveOpType op, const std::vector<Variable>& inputs, Dictionary&& functionConfig, const std::wstring& functionName, const std::wstring& uid)
@@ -804,7 +806,8 @@ namespace CNTK
        // Version 16: Add to_batch/unpack_batch.
        // Version 17: Add Pad.
        // Version 18: Add Crop node.
-       static const size_t s_serializationVersion = 18;
+       // Version 19: Add TopK
+       static const size_t s_serializationVersion = 19;
    };

    std::vector<DictionaryValue> GetInputUids(const Function& f);
@@ -99,4 +99,5 @@ namespace CNTK
    /*static*/ const std::wstring PrimitiveFunction::AttributeNameBias = L"bias";
    /*static*/ const std::wstring PrimitiveFunction::AttributeNameDepthRadius = L"depthRadius";
    /*static*/ const std::wstring PrimitiveFunction::AttributeNameCustomAttributes = L"customAttributes";
    /*static*/ const std::wstring PrimitiveFunction::AttributeNameNumItems = L"numItems";
}
@@ -96,6 +96,7 @@ namespace CNTK
        Crop = 84,
        Atanh = 85,
        Asinh = 86,
        TopK = 87,
        // New op types should only be appended to the end of this list
        UnknownOP
        // and UnknownOP should always be last.
@@ -17,6 +17,7 @@
#include <list>
#include <memory>
#include <algorithm>
#include <numeric>
#include <assert.h>

namespace Microsoft { namespace MSR { namespace CNTK {
@@ -446,6 +447,114 @@ public:
template class HardmaxNode<float>;
template class HardmaxNode<double>;

template <class ElemType>
class TopKNode : public ComputationNode<ElemType>, public MultiOutputNode<ElemType>, public NumInputs<1>
{
    typedef ComputationNode<ElemType> Base; UsingComputationNodeMembersBoilerplate;
    static const std::wstring TypeName() { return L"TopK"; }

public:
    TopKNode(DEVICEID_TYPE deviceId, const wstring& name) : Base(deviceId, name), MultiOutputNode<ElemType>(2) {}
    TopKNode(DEVICEID_TYPE deviceId, const wstring& name, size_t k)
        : Base(deviceId, name), MultiOutputNode<ElemType>(2), m_k(k) {}

    virtual void RequestMatricesBeforeForwardProp(MatrixPool& matrixPool) override
    {
        Base::RequestMatricesBeforeForwardProp(matrixPool);
        RequestMatrixFromPool(m_sortedIndices, matrixPool);
    }

    virtual void ReleaseMatricesAfterBackprop(MatrixPool& matrixPool) override
    {
        Base::ReleaseMatricesAfterBackprop(matrixPool);
        ReleaseMatrixToPool(m_sortedIndices, matrixPool);
    }

    virtual void /*ComputationNode::*/ ForwardProp(const FrameRange& fr) override
    {
#ifdef _MSC_VER
        auto& outputValuePtrRef = this->ValuePtrRef();
        auto& inputValuePtrRef = Input(0)->ValuePtrRef();
#else
        auto& outputValuePtrRef = this->template ValuePtrRef();
        auto& inputValuePtrRef = Input(0)->template ValuePtrRef();
#endif
        auto dim = Input(0)->GetSampleLayout().GetDimPadded(0);
        if (m_k > dim)
            LogicError("TopK: number of requested elements k (=%zd) exceeds total number of elements (=%zd) on this axis", m_k, dim);

        auto&& topkOutput = outputValuePtrRef->Reshaped(m_k, outputValuePtrRef->GetNumElements() / m_k);
        auto&& topkInput = inputValuePtrRef->Reshaped(dim, inputValuePtrRef->GetNumElements() / dim);
        topkInput.VectorMax(*m_sortedIndices, topkOutput, true, m_k);
        this->m_outputsValue[1]->SetValue(m_sortedIndices->Reshaped(outputValuePtrRef->GetNumRows(), outputValuePtrRef->GetNumCols()));
    }

    virtual void /*ComputationNode::*/ BackpropTo(const size_t inputIndex, const FrameRange& fr) override
    {
        // Backpropagation works the same way as for other nodes that take top element(s) such as max pooling.
        // The values that are not selected get a gradient of zero, otherwise the gradient is copied to the
        // positions that were responsible for the top values. This is a scatter operation.
#ifdef _MSC_VER
        auto&& inputGradient = Input(0)->GradientPtrRef();
        auto&& outputGradient = GradientPtrRef();
#else
        auto&& inputGradient = Input(0)->template GradientPtrRef();
        auto&& outputGradient = this->template GradientPtrRef();
#endif

        auto&& reshapedInputGradient = inputGradient->Reshaped(1, inputGradient->GetNumElements());
        auto&& reshapedOutputGradient = outputGradient->Reshaped(1, outputGradient->GetNumElements());

        // The indices take values between 0 and the dimension of the axis over which we compute the top k.
        // Since the matrix class lacks a scatter that can handle indices arising from gather operations
        // over a particular axis of a multidimensional tensor, we patch the indices here so that they look
        // as if they were generated from a gather-like operation over a 1-dimensional tensor.
        auto numCols = m_sortedIndices->GetNumCols();
        if (numCols != 1)
        {
            CreateMatrixIfNull(m_steps);
            auto dim = Input(0)->GetSampleLayout().GetDimPadded(0);
            auto tmp = new ElemType[numCols];
            std::generate(tmp, tmp + numCols, [i = ElemType(0), dim]() mutable { auto ret = i; i += dim; return ret; });
            m_steps->SetValue(1, numCols, this->m_deviceId, tmp);
            delete[] tmp;
            m_sortedIndices->ScaleAndAdd(ElemType(1), *m_steps, *m_sortedIndices);
        }
        reshapedInputGradient.DoScatterColumnsOf(ElemType(1), m_sortedIndices->Reshaped(1, m_sortedIndices->GetNumElements()), reshapedOutputGradient, ElemType(1));
    }

    virtual bool OutputUsedInComputingInputNodesGradients() const override { return false; }
    virtual bool InputUsedInComputingInputNodesGradients(size_t childIndex) const override { return false; }

    virtual void Validate(bool isFinalValidationPass) override
    {
        assert(m_inputs.size() == 1);
        ComputationNodeBase::Validate(isFinalValidationPass);
        InferMBLayoutFromInputsForStandardCase(isFinalValidationPass);

        auto&& inputShape = Input(0)->GetSampleLayout();
        SmallVector<size_t> outDims = inputShape.GetDims();
        if (outDims.size() > 0)
            outDims[0] = m_k;
        auto outShape = TensorShape(outDims);
        SetDims(outShape, Input(0)->HasMBLayout());
        this->m_outputsMBLayout[1] = Input(0)->GetMBLayout();
        this->m_outputsShape[1] = outShape;
    }

private:
    shared_ptr<Matrix<ElemType>> m_sortedIndices;
    shared_ptr<Matrix<ElemType>> m_steps;
    size_t m_k;
};

template class TopKNode<float>;
template class TopKNode<double>;

// -----------------------------------------------------------------------
// If (flag, ifValue, elseValue)
// -----------------------------------------------------------------------
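BackpropTo routes each output gradient back to the input position that produced the corresponding top value, with zeros everywhere else. The m_steps patching turns per-column indices into flat 1-D indices (column j gets offset j*dim) so a single columnwise scatter suffices. A numpy sketch of that backward rule, assuming column-major storage as in CNTK matrices (illustrative names, not the actual API):

    import numpy as np

    def topk_backward(out_grad, indices, dim):
        # out_grad, indices: shape (k, n_cols); result: input gradient (dim, n_cols).
        k, n_cols = out_grad.shape
        # Patch per-column indices into flat 1-D indices, mirroring the m_steps trick.
        flat_idx = (indices + np.arange(n_cols) * dim).ravel(order='F').astype(int)
        in_grad = np.zeros(dim * n_cols, dtype=out_grad.dtype)
        # Scatter: unselected positions keep gradient zero.
        np.add.at(in_grad, flat_idx, out_grad.ravel(order='F'))
        return in_grad.reshape((dim, n_cols), order='F')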
@@ -22,6 +22,7 @@
#include <thread>
#include <iostream>
#include <algorithm>
#include <numeric>
#pragma warning(push)
#pragma warning(disable:4244) // 'conversion' conversion from 'type1' to 'type2', possible loss of data
#include <boost/random/normal_distribution.hpp>
@@ -3826,23 +3827,19 @@ void CPUMatrix<ElemType>::VectorMax(CPUMatrix<ElemType>& maxIndexes, CPUMatrix<E
    else
    {
        std::vector<int> indices(m);
-       int i = 0;
-       std::generate(indices.begin(), indices.end(), [&i]
-       {
-           return i++;
-       });

        const ElemType* curVal = Data();
        ElemType* curIdx = maxIndexes.Data();
        ElemType* curMax = maxValues.Data();
        for (int icol = 0; icol < n; icol++, curVal += m, curIdx += topK, curMax += topK)
        {
+           std::iota(indices.begin(), indices.end(), 0);
            // Partial sort, descending order.
-           std::nth_element(indices.begin(), indices.begin() + topK, indices.end(),
-                            [curVal](const int& a, const int& b)
-                            {
-                                return curVal[a] > curVal[b];
-                            });
+           std::partial_sort(indices.begin(), indices.begin() + topK, indices.end(),
+                             [curVal](const int& a, const int& b)
+                             {
+                                 return curVal[a] > curVal[b];
+                             });
            // REVIEW alexeyk: the following produces warning (see SCL_SECURE_NO_WARNINGS) so use loop instead.
            // std::transform(indices.begin(), indices.begin() + topK, curIdx, [](const int& a) { return static_cast<ElemType>(a); });
            for (int i2 = 0; i2 < topK; i2++)
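std::nth_element gives way to std::partial_sort here because TopK promises its values in sorted order: nth_element only partitions around the k-th element and leaves the leading k unordered, while partial_sort additionally sorts them. The same distinction in numpy terms (illustrative):

    import numpy as np

    x = np.array([3.0, 9.0, 1.0, 7.0, 5.0])
    k = 3
    # np.argpartition ~ std::nth_element: the top k are correct but unordered.
    unordered = np.argpartition(-x, k)[:k]
    # Sorting those k entries ~ std::partial_sort: descending order, as TopK requires.
    ordered = unordered[np.argsort(-x[unordered])]
    print(x[ordered])  # [9. 7. 5.]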
@@ -79,7 +79,7 @@
      <PreprocessorDefinitions Condition="'$(CNTK_ENABLE_1BitSGD)'=='true' and '!$(IsUWP)'">CNTK_PARALLEL_TRAINING_SUPPORT;%(PreprocessorDefinitions)</PreprocessorDefinitions>
      <SDLCheck>true</SDLCheck>
      <FavorSizeOrSpeed>Speed</FavorSizeOrSpeed>
-     <AdditionalOptions>/d2Zi+ %(AdditionalOptions)</AdditionalOptions>
+     <AdditionalOptions>/d2Zi+ /bigobj %(AdditionalOptions)</AdditionalOptions>
    </ClCompile>
  </ItemDefinitionGroup>
  <ItemDefinitionGroup Condition="$(CpuOnlyBuild)">

@@ -156,4 +156,4 @@
  <Target Name="CheckDependencies">
    <Error Condition="!$(HasMultiverso) And '$(CNTK_ENABLE_ASGD)'!='false' and '!$(IsUWP)'" Text="CNTK requires Multiverso to build, Please see https://docs.microsoft.com/en-us/cognitive-toolkit/Setup-CNTK-on-Windows#enlisting-in-the-cntk-github-repository for installation instructions." />
  </Target>
-</Project>
+</Project>
@@ -346,7 +346,8 @@ void CheckEnumValuesNotModified() {
    static_cast<size_t>(PrimitiveOpType::Pad) == 83 &&
    static_cast<size_t>(PrimitiveOpType::Crop) == 84 &&
    static_cast<size_t>(PrimitiveOpType::Atanh) == 85 &&
-   static_cast<size_t>(PrimitiveOpType::Asinh) == 86,
+   static_cast<size_t>(PrimitiveOpType::Asinh) == 86 &&
+   static_cast<size_t>(PrimitiveOpType::TopK) == 87,
    "PrimitiveOpType enum value was modified.");
}
@@ -1724,6 +1724,39 @@ def hardmax(x, name=''):
    return hardmax(x, name)


@typemap
def top_k(x, k, axis=-1, name=''):
    '''
    Computes the ``k`` largest values of the input tensor and the corresponding indices
    along the specified axis (default: the last axis). The returned
    :class:`~cntk.ops.functions.Function` has two outputs. The first one
    contains the top ``k`` values in sorted order, and the second one contains
    the corresponding top ``k`` indices.

    Example:
        >>> x = C.input_variable(10)
        >>> y = C.top_k(-x * C.log(x), 3)
        >>> x0 = np.arange(10,dtype=np.float32)*0.1
        >>> top = y.eval({x:x0})
        >>> top_values = top[y.outputs[0]]
        >>> top_indices = top[y.outputs[1]]
        >>> top_indices
        array([[ 4.,  3.,  5.]], dtype=float32)

    Args:
        x: numpy array or any :class:`~cntk.ops.functions.Function` that outputs a tensor
        k (int): number of top items to return
        axis: axis along which to perform the operation (default: -1)
        name (str): the name of the Function instance in the network
    Returns:
        :class:`~cntk.ops.functions.Function`
    '''
    from cntk.cntk_py import top_k
    x = sanitize_input(x)
    axis = sanitize_axis(axis)
    return top_k(x, k, axis, name)


@typemap
def exp(x, name=''):
    '''
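Because ``axis`` defaults to -1, selecting along another static axis just means passing it explicitly. A small usage sketch in the same style as the doctest (hypothetical shapes and values, output shapes assumed from the inference rule above):

    import numpy as np
    import cntk as C

    # Top 2 along the first static axis of a (3, 4) input.
    x = C.input_variable((3, 4))
    y = C.top_k(x, 2, axis=0)
    x0 = np.random.randn(1, 3, 4).astype(np.float32)  # batch of one sample
    result = y.eval({x: x0})
    values, indices = result[y.outputs[0]], result[y.outputs[1]]  # each (1, 2, 4)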
@@ -696,3 +696,86 @@ def test_crop():
    cropped = C.crop_automatic_with_ancestors(
        node_output, node_referent, node_input, node_referent).eval(input_map)
    assert np.array_equal(cropped, expected)


@pytest.mark.parametrize("axis", [-2, -1])
def test_topk(axis, device_id, precision):
    def sliceit(x, axis):
        if axis not in (-2, -1):
            raise ValueError("unknown axis %d"%axis)
        if axis == -1:
            return x[..., -1:-4:-1]
        elif axis == -2:
            return x[..., -1:-4:-1, :]

    def check_topk_values_and_indices(top, y, x):
        vals = top[y.outputs[0]]
        idxs = top[y.outputs[1]]
        for vi, xi in zip(vals, x):
            assert np.allclose(vi, sliceit(np.sort(xi, axis=axis), axis))
        for idxi, xi in zip(idxs, x):
            assert np.allclose(idxi, sliceit(np.argsort(xi, axis=axis), axis))

    dt = PRECISION_TO_TYPE[precision]
    dev = cntk_device(device_id)

    p = C.parameter((10, 20, 30), dtype=dt)
    np.random.seed(90210)
    p.value = p.value + np.random.randn(*p.shape)
    y = C.top_k(p, 3, axis=axis)
    top = y.eval({})  # for now run this on the device where the parameter is
    assert np.allclose(top[y.outputs[0]], sliceit(np.sort(p.value, axis=axis), axis))
    assert np.allclose(top[y.outputs[1]], sliceit(np.argsort(p.value, axis=axis), axis))

    q = C.input_variable((5, 6), dtype=dt)
    q0 = np.random.randn(2, 5, 6).astype(dt)
    y = C.top_k(q, 3, axis=axis)
    top = y.eval({q: q0}, device=dev)
    check_topk_values_and_indices(top, y, q0)

    q = C.sequence.input_variable((5, 6), dtype=dt)
    q0 = [np.random.randn(4-i, 5, 6).astype(dt) for i in range(2)]
    y = C.top_k(q, 3, axis=axis)
    top = y.eval({q: q0}, device=dev)
    check_topk_values_and_indices(top, y, q0)


def test_topk_backward(device_id, precision):
    def check_grad_last_axis(input, root, indices, output):
        d = input.shape[-1]
        k = indices.shape[-1]
        expected_output = np.zeros_like(input).reshape(-1, d)
        ind = np.reshape(indices, (-1, k))
        r = np.reshape(root, (-1, k))
        assert ind.shape[0] == r.shape[0] == expected_output.shape[0]
        for i in range(expected_output.shape[0]):
            for j in range(k):
                expected_output[i, int(ind[i, j])] = r[i, j]
        expected_output = expected_output.reshape(input.shape)
        assert np.allclose(output, expected_output)

    dt = PRECISION_TO_TYPE[precision]
    dev = cntk_device(device_id)

    axis = -1
    h = C.placeholder()
    p = C.parameter((4, 5, 6))
    p.value = p.value + np.random.randn(*p.shape)
    y = C.top_k(h, 3, axis=axis)
    y.replace_placeholder(p)
    dy, top = y.forward({}, y.outputs, set([y.outputs[0]]))
    indices = top[y.outputs[1]]
    root = np.ones_like(indices)
    root = root + np.arange(np.prod(root.shape)).reshape(*root.shape)
    cg = y.backward(dy, {y.outputs[0]: root}, set([p]))[p]
    check_grad_last_axis(p.value, root, indices, cg)

    q = C.sequence.input_variable((5, 6), needs_gradient=True)
    q0 = [np.random.randn(4-i, 5, 6).astype(dt) for i in range(2)]
    y = C.top_k(q, 3, axis=axis)
    dy, top = y.forward({q: q0}, y.outputs, set([y.outputs[0]]), device=dev)
    indices = top[y.outputs[1]]
    root = [np.ones_like(i) + 100 * k + np.arange(np.prod(i.shape)).reshape(*i.shape) for k, i in enumerate(indices)]
    cg = y.backward(dy, {y.outputs[0]: root}, set([q]))[q]
    for i in range(2):
        check_grad_last_axis(q0[i], root[i], indices[i], cg[i])