Add ONNX support for zeros_like, ones_like, and eye_like.

This commit is contained in:
Spandan Tiwari 2018-12-05 13:51:43 -08:00
Родитель 5fc6c2a26a
Коммит 22e869ec42
7 изменённых файлов: 74 добавлений и 10 удалений

Просмотреть файл

@ -4161,7 +4161,12 @@ namespace CNTK
/// Create an instance of the expand dims operation on specified tensor input operand, for the specified axis
///
CNTK_API FunctionPtr ExpandDims(const Variable& operand, const Axis& axis, const std::wstring& name = L"");
//
/// Create an instance of a constant-like operation. This produces a tensor with given constant value with the shape and dynamic axes specified by the operand.
///
CNTK_API FunctionPtr ConstantLike(const Variable& operand, const double value, const std::wstring& name = L"");
///
/// Create an instance of a zeros-like operation. This produces zeros with the shape and dynamic axes specified by the operand.
///

Просмотреть файл

@ -1714,20 +1714,22 @@ namespace CNTK
return AsBlock(std::move(result), { { operandPlaceholder, operand }}, L"ExpandDims", name);
}
FunctionPtr ZerosLike(const Variable& operand, const std::wstring& name)
FunctionPtr ConstantLike(const Variable& operand, const double value, const std::wstring& name)
{
auto additionalProperties = Dictionary();
additionalProperties[PrimitiveFunctionAttribute::AttributeNameFillValue] = 0.0;
additionalProperties[PrimitiveFunctionAttribute::AttributeNameFillValue] = value;
return UnaryOp(PrimitiveOpType::ConstantOp, operand, std::move(additionalProperties), name);
}
FunctionPtr ZerosLike(const Variable& operand, const std::wstring& name)
{
return ConstantLike(operand, 0.0, name);
}
FunctionPtr OnesLike(const Variable& operand, const std::wstring& name)
{
auto additionalProperties = Dictionary();
additionalProperties[PrimitiveFunctionAttribute::AttributeNameFillValue] = 1.0;
return UnaryOp(PrimitiveOpType::ConstantOp, operand, std::move(additionalProperties), name);
return ConstantLike(operand, 1.0, name);
}
FunctionPtr CustomProxyOp(const std::vector<Variable>& operands, const std::wstring& customOp, const NDShape& outputShape, DataType outputType, const std::wstring& name)

Просмотреть файл

@ -5743,6 +5743,19 @@ void CNTKToONNXHelper::CopyAttributes(const FunctionPtr& src, onnxruntime::Node*
size_t k = src->Attributes()[L"numItems"].Value<size_t>();
node->AddAttribute(attributesMap[L"numItems"], static_cast<int64_t>(k));
}
else if (src->OpName() == L"EyeLikeOp")
{
bool isOutputSparse = src->Attributes().Contains(L"OutputSparse") ? (bool)src->Attributes()[L"OutputSparse"].Value<bool>() : false;
if(isOutputSparse)
LogicError("Node '%S': 'OutputSparse' is True. Sparse format export not supported.", src->AsString().c_str());
}
else if (src->OpName() == L"ConstantOp")
{
if(!src->Attributes().Contains(L"fillValue"))
LogicError("Node '%S': 'fillValue' not present. Cannot export op.", src->AsString().c_str());
auto fillValue = static_cast<float>(src->Attributes()[L"fillValue"].Value<double>());
node->AddAttribute("value", fillValue);
}
else if (src->OpName() == L"Crop")
{
const NDShape& inputShape = src->Inputs()[0].Shape();
@ -6155,8 +6168,6 @@ onnxruntime::Node* CNTKToONNXHelper::AddNode(const FunctionPtr& src, onnxruntime
//
if (src->OpName() == L"Times")
{
if (src->Uid() == L"Times4771")
std::cout << "";
size_t py_api_output_rank_argument = src->Attributes()[L"outputRank"].Value<size_t>();
auto input1Shape = orderedInputs[0]->Shape();
auto input2Shape = orderedInputs[1]->Shape();

Просмотреть файл

@ -2929,6 +2929,21 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector
FunctionPtr cntkFunction = TopK(inputs[0], k, axis, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "EyeLike")
{
// Only limited import support is provided.
FunctionPtr cntkFunction = EyeLike(inputs[0], false, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "ConstantLike")
{
// Limited import support implemented. 'shape' attribute
// node syntax not supported. Only syntax with input tensor
// for shape and 'value' attribute for value is supported.
float value = GetNamedAttributeAsFloat(node, "value", 0.0f);
FunctionPtr cntkFunction = ConstantLike(inputOperand0, static_cast<double>(value), ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Crop")
{
// inputShape: [W, H, C] x [N]

Просмотреть файл

@ -477,6 +477,12 @@ namespace ONNX
{ L"OneHotOp", { {
{ L"OneHotOp", "OneHotEncoder"},
} } },
{ L"EyeLikeOp",{ {
{ L"EyeLikeOp", "EyeLike" },
} } },
{ L"ConstantOp",{ {
{ L"ConstantOp", "ConstantLike" },
} } },
};
// given a cntkOpName and cntk attribute OpName which is saved in CNTK::Function's attribute,

Просмотреть файл

@ -2062,4 +2062,28 @@ def test_Crop_Manual(tmpdir, dtype):
y = C.constant(np.ones((1,2,1), dtype=np.float32))
model = C.crop_manual(x, y, 1, 2, name='crop_manual')
data = np.asarray(range(4*4), dtype=np.float32).reshape((1,4,4))
verify_one_input(model, data, tmpdir, "Crop_Manual_0")
verify_one_input(model, data, tmpdir, "Crop_Manual_0")
# eye_like
@pytest.mark.parametrize("dtype", DType_Config)
def test_Eye_Like(tmpdir, dtype):
x = C.input_variable((4, 4), dynamic_axes=[], dtype=dtype, name='feature')
model = C.eye_like(x, sparse_output=False)
data = np.asarray(range(4*4), dtype=dtype).reshape((4,4))
verify_one_input(model, data, tmpdir, "Eye_Like_0")
# zeros_like
@pytest.mark.parametrize("dtype", DType_Config)
def test_Zeros_Like(tmpdir, dtype):
x = C.input_variable((3, 4), dynamic_axes=[], dtype=dtype, name='feature')
model = C.zeros_like(x, name='zeros_like_op')
data = np.asarray(range(3*4), dtype=dtype).reshape((3,4))
verify_one_input(model, data, tmpdir, "Zeros_Like_0")
# ones_like
@pytest.mark.parametrize("dtype", DType_Config)
def test_Ones_Like(tmpdir, dtype):
x = C.input_variable((3, 4), dynamic_axes=[], dtype=dtype, name='feature')
model = C.ones_like(x, name='ones_like_op')
data = np.asarray(range(3*4), dtype=dtype).reshape((3,4))
verify_one_input(model, data, tmpdir, "Ones_Like_0")

Просмотреть файл

@ -30,6 +30,7 @@ known_issues = [
'MVN_1',
'MVN_2',
'MVN_3',
'Eye_Like_0',
]
def parse_single_result_case(case_str):