support ONNX1.2.2 cast op
This commit is contained in:
Родитель
bece037f63
Коммит
07fd96f357
|
@ -195,6 +195,11 @@ private:
|
|||
static std::vector<float> INTSToVecFloat(const std::vector<int64_t> &ints);
|
||||
static std::vector<int64_t> ConvertPermutationCNTKToONNX(const std::vector<Axis> &axes, bool hasBatchAxis);
|
||||
|
||||
//
|
||||
// Convert DataType from CNTK TensorProto
|
||||
//
|
||||
static TensorProto_DataType ConvertDataTypeCNTKToTensorProto(CNTK::DataType newDataType);
|
||||
|
||||
//
|
||||
// Convert data types from CNTK to ONNX.
|
||||
//
|
||||
|
@ -1190,24 +1195,29 @@ void MapAndUpdateONNXType(const std::string &op, bool inputArg, int argOrder, CN
|
|||
type.mutable_tensor_type()->set_elem_type(onnx::TensorProto_DataType_FLOAT);
|
||||
}
|
||||
|
||||
void CNTKToONNXHelper::UpdateONNXType(CNTK::DataType dataType, onnx::TypeProto &type)
|
||||
TensorProto_DataType CNTKToONNXHelper::ConvertDataTypeCNTKToTensorProto(
|
||||
CNTK::DataType newDataType)
|
||||
{
|
||||
switch (dataType)
|
||||
// to TensorProto_DataType
|
||||
switch (newDataType)
|
||||
{
|
||||
case CNTK::DataType::Float:
|
||||
type.mutable_tensor_type()->set_elem_type(onnx::TensorProto_DataType_FLOAT);
|
||||
break;
|
||||
case CNTK::DataType::Float16:
|
||||
type.mutable_tensor_type()->set_elem_type(onnx::TensorProto_DataType_FLOAT16);
|
||||
break;
|
||||
return TensorProto_DataType::TensorProto_DataType_FLOAT;
|
||||
case CNTK::DataType::Double:
|
||||
type.mutable_tensor_type()->set_elem_type(onnx::TensorProto_DataType_DOUBLE);
|
||||
break;
|
||||
return TensorProto_DataType::TensorProto_DataType_DOUBLE;
|
||||
case CNTK::DataType::Float16:
|
||||
return TensorProto_DataType::TensorProto_DataType_FLOAT16;
|
||||
default:
|
||||
NOT_IMPLEMENTED;
|
||||
}
|
||||
}
|
||||
|
||||
void CNTKToONNXHelper::UpdateONNXType(CNTK::DataType dataType, onnx::TypeProto &type)
|
||||
{
|
||||
TensorProto_DataType tensorProtoDataType = ConvertDataTypeCNTKToTensorProto(dataType);
|
||||
type.mutable_tensor_type()->set_elem_type(tensorProtoDataType);
|
||||
}
|
||||
|
||||
std::string CNTKToONNXHelper::ToOPName(const FunctionPtr& src)
|
||||
{
|
||||
auto lookup = Operators::CntkToONNXLookup();
|
||||
|
@ -2847,6 +2857,12 @@ void CNTKToONNXHelper::CopyAttributes(const FunctionPtr& src, LotusIR::Node* nod
|
|||
node->AddAttribute("min", minValue);
|
||||
node->AddAttribute("max", maxValue);
|
||||
}
|
||||
else if (src->OpName() == L"Cast")
|
||||
{
|
||||
DataType newDataType = static_cast<DataType>(src->Attributes()[L"newDataType"].Value<int>());
|
||||
int64_t to = static_cast<int64_t>(ConvertDataTypeCNTKToTensorProto(newDataType));
|
||||
node->AddAttribute(attributesMap[L"newDataType"], to);
|
||||
}
|
||||
if (src->OpName() == L"BatchNormalization")
|
||||
{
|
||||
auto spatial = (int64_t)((bool)src->Attributes()[L"spatial"].Value<bool>() ? 1 : 0);
|
||||
|
|
|
@ -1911,17 +1911,27 @@ std::pair<std::vector<size_t>, std::vector<size_t>> ONNXToCNTKHelper::AdjustONNX
|
|||
return SplitAndReverseVec(pads);
|
||||
}
|
||||
|
||||
CNTK::DataType ConvertDataTypeTensorProtoToCNTK(TensorProto_DataType newDataType)
|
||||
{
|
||||
// to TensorProto_DataType
|
||||
switch (newDataType)
|
||||
{
|
||||
case TensorProto_DataType::TensorProto_DataType_FLOAT:
|
||||
return CNTK::DataType::Float;
|
||||
case TensorProto_DataType::TensorProto_DataType_DOUBLE:
|
||||
return CNTK::DataType::Double;
|
||||
case TensorProto_DataType::TensorProto_DataType_FLOAT16:
|
||||
return CNTK::DataType::Float16;
|
||||
default:
|
||||
NOT_IMPLEMENTED;
|
||||
}
|
||||
}
|
||||
|
||||
FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector<Variable> &inputs, const Graph *graph)
|
||||
{
|
||||
string onnxOpName = node->OpType();
|
||||
|
||||
if (onnxOpName == "Cast" && inputs[0].GetDataType() == CNTK::DataType::Float && inputs[0].Owner() != nullptr)
|
||||
{
|
||||
// CNTK does not support cast op. Only float is available with ONNX support.
|
||||
// Question for having a cast op: Why not cast data as necessary internally.
|
||||
return inputs[0].Owner();
|
||||
}
|
||||
else if (onnxOpName == "LSTM")
|
||||
if (onnxOpName == "LSTM")
|
||||
{
|
||||
const string direction = GetNamedAttributeAsString(node, "direction");
|
||||
std::vector<float> activation_alpha = GetNamedAttributeAsFloatVec(node, "activation_alpha", std::vector<float>());
|
||||
|
@ -2721,6 +2731,13 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector
|
|||
FunctionPtr cntkFunction = Acos(inputs[0], ToWString(node->Name()));
|
||||
return cntkFunction;
|
||||
}
|
||||
else if (onnxOpName == "Cast")
|
||||
{
|
||||
TensorProto_DataType newDataType = static_cast<TensorProto_DataType>(GetNamedAttributeAsInt64(node, "to"));
|
||||
DataType cntkNewDataType = ConvertDataTypeTensorProtoToCNTK(newDataType);
|
||||
FunctionPtr cntkFunction = Cast(inputs[0], cntkNewDataType, ToWString(node->Name()));
|
||||
return cntkFunction;
|
||||
}
|
||||
else if (onnxOpName == "Tan")
|
||||
{
|
||||
FunctionPtr cntkFunction = Tan(inputs[0], ToWString(node->Name()));
|
||||
|
|
|
@ -352,7 +352,10 @@ namespace ONNX
|
|||
} } },
|
||||
|
||||
// From tensor
|
||||
// { L"", "Cast" },
|
||||
{ L"Cast", { {
|
||||
{ L"Cast", "Cast" },
|
||||
{ L"newDataType", "to" },
|
||||
} } },
|
||||
{ L"Splice", { {
|
||||
{ L"Splice", "Concat" },
|
||||
{ L"axis", "axis" },
|
||||
|
|
|
@ -30,7 +30,7 @@ inline Status FileOpenRd(const std::wstring& path, /*out*/ int* p_fd) {
|
|||
}
|
||||
|
||||
inline Status FileOpenWr(const std::wstring& path, /*out*/ int* p_fd) {
|
||||
_wsopen_s(p_fd, path.c_str(), _O_CREAT | _O_SEQUENTIAL | _O_BINARY | _O_WRONLY, _SH_DENYWR, _S_IREAD | _S_IWRITE);
|
||||
_wsopen_s(p_fd, path.c_str(), _O_CREAT | O_TRUNC | _O_SEQUENTIAL | _O_BINARY | _O_WRONLY, _SH_DENYWR, _S_IREAD | _S_IWRITE);
|
||||
if (0 > *p_fd) {
|
||||
return Status(SYSTEM, errno);
|
||||
}
|
||||
|
@ -52,7 +52,7 @@ inline Status FileOpenRd(const std::string& path, /*out*/ int* p_fd) {
|
|||
|
||||
inline Status FileOpenWr(const std::string& path, /*out*/ int* p_fd) {
|
||||
#ifdef _WIN32
|
||||
_sopen_s(p_fd, path.c_str(), _O_CREAT | _O_SEQUENTIAL | _O_BINARY | _O_WRONLY, _SH_DENYWR, _S_IREAD | _S_IWRITE);
|
||||
_sopen_s(p_fd, path.c_str(), _O_CREAT | O_TRUNC | _O_SEQUENTIAL | _O_BINARY | _O_WRONLY, _SH_DENYWR, _S_IREAD | _S_IWRITE);
|
||||
#else
|
||||
*p_fd = open(path.c_str(), O_WRONLY | O_CREAT | O_TRUNC, 0644);
|
||||
#endif
|
||||
|
|
|
@ -234,6 +234,18 @@ def test_BatchNormalization(tmpdir, dtype):
|
|||
|
||||
verify_one_input(op_node, t, tmpdir, 'BatchNormalization')
|
||||
|
||||
#Cast
|
||||
Cast_Type_Config = (np.float64, np.float32, np.float16)
|
||||
@pytest.mark.parametrize("from_type", Cast_Type_Config)
|
||||
@pytest.mark.parametrize("to_type", Cast_Type_Config)
|
||||
def test_Cast(tmpdir, from_type, to_type):
|
||||
test_name = "cast_" + from_type.__name__ + "_to_" + to_type.__name__
|
||||
shape = (3, 10, 15)
|
||||
input_var = C.input_variable(shape, dtype = from_type, name='features')
|
||||
model = C.cast(input_var, dtype=to_type)
|
||||
data = np.random.rand(*shape).astype(from_type)
|
||||
verify_one_input(model, data, tmpdir, test_name)
|
||||
|
||||
# Ceil
|
||||
@pytest.mark.parametrize("dtype", DType_Config)
|
||||
def test_Ceil(tmpdir, dtype):
|
||||
|
|
Загрузка…
Ссылка в новой задаче