Add ONNX support for zeros_like, ones_like, and eye_like.
This commit is contained in:
Родитель
5fc6c2a26a
Коммит
22e869ec42
|
@ -4161,7 +4161,12 @@ namespace CNTK
|
|||
/// Create an instance of the expand dims operation on specified tensor input operand, for the specified axis
|
||||
///
|
||||
CNTK_API FunctionPtr ExpandDims(const Variable& operand, const Axis& axis, const std::wstring& name = L"");
|
||||
|
||||
|
||||
//
|
||||
/// Create an instance of a constant-like operation. This produces a tensor with given constant value with the shape and dynamic axes specified by the operand.
|
||||
///
|
||||
CNTK_API FunctionPtr ConstantLike(const Variable& operand, const double value, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of a zeros-like operation. This produces zeros with the shape and dynamic axes specified by the operand.
|
||||
///
|
||||
|
|
|
@ -1714,20 +1714,22 @@ namespace CNTK
|
|||
return AsBlock(std::move(result), { { operandPlaceholder, operand }}, L"ExpandDims", name);
|
||||
}
|
||||
|
||||
FunctionPtr ZerosLike(const Variable& operand, const std::wstring& name)
|
||||
FunctionPtr ConstantLike(const Variable& operand, const double value, const std::wstring& name)
|
||||
{
|
||||
auto additionalProperties = Dictionary();
|
||||
additionalProperties[PrimitiveFunctionAttribute::AttributeNameFillValue] = 0.0;
|
||||
additionalProperties[PrimitiveFunctionAttribute::AttributeNameFillValue] = value;
|
||||
|
||||
return UnaryOp(PrimitiveOpType::ConstantOp, operand, std::move(additionalProperties), name);
|
||||
}
|
||||
|
||||
FunctionPtr ZerosLike(const Variable& operand, const std::wstring& name)
|
||||
{
|
||||
return ConstantLike(operand, 0.0, name);
|
||||
}
|
||||
|
||||
FunctionPtr OnesLike(const Variable& operand, const std::wstring& name)
|
||||
{
|
||||
auto additionalProperties = Dictionary();
|
||||
additionalProperties[PrimitiveFunctionAttribute::AttributeNameFillValue] = 1.0;
|
||||
|
||||
return UnaryOp(PrimitiveOpType::ConstantOp, operand, std::move(additionalProperties), name);
|
||||
return ConstantLike(operand, 1.0, name);
|
||||
}
|
||||
|
||||
FunctionPtr CustomProxyOp(const std::vector<Variable>& operands, const std::wstring& customOp, const NDShape& outputShape, DataType outputType, const std::wstring& name)
|
||||
|
|
|
@ -5743,6 +5743,19 @@ void CNTKToONNXHelper::CopyAttributes(const FunctionPtr& src, onnxruntime::Node*
|
|||
size_t k = src->Attributes()[L"numItems"].Value<size_t>();
|
||||
node->AddAttribute(attributesMap[L"numItems"], static_cast<int64_t>(k));
|
||||
}
|
||||
else if (src->OpName() == L"EyeLikeOp")
|
||||
{
|
||||
bool isOutputSparse = src->Attributes().Contains(L"OutputSparse") ? (bool)src->Attributes()[L"OutputSparse"].Value<bool>() : false;
|
||||
if(isOutputSparse)
|
||||
LogicError("Node '%S': 'OutputSparse' is True. Sparse format export not supported.", src->AsString().c_str());
|
||||
}
|
||||
else if (src->OpName() == L"ConstantOp")
|
||||
{
|
||||
if(!src->Attributes().Contains(L"fillValue"))
|
||||
LogicError("Node '%S': 'fillValue' not present. Cannot export op.", src->AsString().c_str());
|
||||
auto fillValue = static_cast<float>(src->Attributes()[L"fillValue"].Value<double>());
|
||||
node->AddAttribute("value", fillValue);
|
||||
}
|
||||
else if (src->OpName() == L"Crop")
|
||||
{
|
||||
const NDShape& inputShape = src->Inputs()[0].Shape();
|
||||
|
@ -6155,8 +6168,6 @@ onnxruntime::Node* CNTKToONNXHelper::AddNode(const FunctionPtr& src, onnxruntime
|
|||
//
|
||||
if (src->OpName() == L"Times")
|
||||
{
|
||||
if (src->Uid() == L"Times4771")
|
||||
std::cout << "";
|
||||
size_t py_api_output_rank_argument = src->Attributes()[L"outputRank"].Value<size_t>();
|
||||
auto input1Shape = orderedInputs[0]->Shape();
|
||||
auto input2Shape = orderedInputs[1]->Shape();
|
||||
|
|
|
@ -2929,6 +2929,21 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector
|
|||
FunctionPtr cntkFunction = TopK(inputs[0], k, axis, ToFixedWStringFromMultiByte(node->Name()));
|
||||
return cntkFunction;
|
||||
}
|
||||
else if (onnxOpName == "EyeLike")
|
||||
{
|
||||
// Only limited import support is provided.
|
||||
FunctionPtr cntkFunction = EyeLike(inputs[0], false, ToFixedWStringFromMultiByte(node->Name()));
|
||||
return cntkFunction;
|
||||
}
|
||||
else if (onnxOpName == "ConstantLike")
|
||||
{
|
||||
// Limited import support implemented. 'shape' attribute
|
||||
// node syntax not supported. Only syntax with input tensor
|
||||
// for shape and 'value' attribute for value is supported.
|
||||
float value = GetNamedAttributeAsFloat(node, "value", 0.0f);
|
||||
FunctionPtr cntkFunction = ConstantLike(inputOperand0, static_cast<double>(value), ToFixedWStringFromMultiByte(node->Name()));
|
||||
return cntkFunction;
|
||||
}
|
||||
else if (onnxOpName == "Crop")
|
||||
{
|
||||
// inputShape: [W, H, C] x [N]
|
||||
|
|
|
@ -477,6 +477,12 @@ namespace ONNX
|
|||
{ L"OneHotOp", { {
|
||||
{ L"OneHotOp", "OneHotEncoder"},
|
||||
} } },
|
||||
{ L"EyeLikeOp",{ {
|
||||
{ L"EyeLikeOp", "EyeLike" },
|
||||
} } },
|
||||
{ L"ConstantOp",{ {
|
||||
{ L"ConstantOp", "ConstantLike" },
|
||||
} } },
|
||||
};
|
||||
|
||||
// given a cntkOpName and cntk attribute OpName which is saved in CNTK::Function's attribute,
|
||||
|
|
|
@ -2062,4 +2062,28 @@ def test_Crop_Manual(tmpdir, dtype):
|
|||
y = C.constant(np.ones((1,2,1), dtype=np.float32))
|
||||
model = C.crop_manual(x, y, 1, 2, name='crop_manual')
|
||||
data = np.asarray(range(4*4), dtype=np.float32).reshape((1,4,4))
|
||||
verify_one_input(model, data, tmpdir, "Crop_Manual_0")
|
||||
verify_one_input(model, data, tmpdir, "Crop_Manual_0")
|
||||
|
||||
# eye_like
|
||||
@pytest.mark.parametrize("dtype", DType_Config)
|
||||
def test_Eye_Like(tmpdir, dtype):
|
||||
x = C.input_variable((4, 4), dynamic_axes=[], dtype=dtype, name='feature')
|
||||
model = C.eye_like(x, sparse_output=False)
|
||||
data = np.asarray(range(4*4), dtype=dtype).reshape((4,4))
|
||||
verify_one_input(model, data, tmpdir, "Eye_Like_0")
|
||||
|
||||
# zeros_like
|
||||
@pytest.mark.parametrize("dtype", DType_Config)
|
||||
def test_Zeros_Like(tmpdir, dtype):
|
||||
x = C.input_variable((3, 4), dynamic_axes=[], dtype=dtype, name='feature')
|
||||
model = C.zeros_like(x, name='zeros_like_op')
|
||||
data = np.asarray(range(3*4), dtype=dtype).reshape((3,4))
|
||||
verify_one_input(model, data, tmpdir, "Zeros_Like_0")
|
||||
|
||||
# ones_like
|
||||
@pytest.mark.parametrize("dtype", DType_Config)
|
||||
def test_Ones_Like(tmpdir, dtype):
|
||||
x = C.input_variable((3, 4), dynamic_axes=[], dtype=dtype, name='feature')
|
||||
model = C.ones_like(x, name='ones_like_op')
|
||||
data = np.asarray(range(3*4), dtype=dtype).reshape((3,4))
|
||||
verify_one_input(model, data, tmpdir, "Ones_Like_0")
|
|
@ -30,6 +30,7 @@ known_issues = [
|
|||
'MVN_1',
|
||||
'MVN_2',
|
||||
'MVN_3',
|
||||
'Eye_Like_0',
|
||||
]
|
||||
|
||||
def parse_single_result_case(case_str):
|
||||
|
|
Загрузка…
Ссылка в новой задаче