This commit is contained in:
liqfu 2018-07-02 08:38:50 -07:00
Parent bece037f63
Commit 07fd96f357
5 changed files with 67 additions and 19 deletions

View file

@ -195,6 +195,11 @@ private:
static std::vector<float> INTSToVecFloat(const std::vector<int64_t> &ints);
static std::vector<int64_t> ConvertPermutationCNTKToONNX(const std::vector<Axis> &axes, bool hasBatchAxis);
//
// Convert DataType from CNTK TensorProto
//
static TensorProto_DataType ConvertDataTypeCNTKToTensorProto(CNTK::DataType newDataType);
//
// Convert data types from CNTK to ONNX.
//
@ -1190,24 +1195,29 @@ void MapAndUpdateONNXType(const std::string &op, bool inputArg, int argOrder, CN
type.mutable_tensor_type()->set_elem_type(onnx::TensorProto_DataType_FLOAT);
}
void CNTKToONNXHelper::UpdateONNXType(CNTK::DataType dataType, onnx::TypeProto &type)
TensorProto_DataType CNTKToONNXHelper::ConvertDataTypeCNTKToTensorProto(
CNTK::DataType newDataType)
{
switch (dataType)
// to TensorProto_DataType
switch (newDataType)
{
case CNTK::DataType::Float:
type.mutable_tensor_type()->set_elem_type(onnx::TensorProto_DataType_FLOAT);
break;
case CNTK::DataType::Float16:
type.mutable_tensor_type()->set_elem_type(onnx::TensorProto_DataType_FLOAT16);
break;
return TensorProto_DataType::TensorProto_DataType_FLOAT;
case CNTK::DataType::Double:
type.mutable_tensor_type()->set_elem_type(onnx::TensorProto_DataType_DOUBLE);
break;
return TensorProto_DataType::TensorProto_DataType_DOUBLE;
case CNTK::DataType::Float16:
return TensorProto_DataType::TensorProto_DataType_FLOAT16;
default:
NOT_IMPLEMENTED;
}
}
// Set the element type of an ONNX TypeProto from the given CNTK DataType.
void CNTKToONNXHelper::UpdateONNXType(CNTK::DataType dataType, onnx::TypeProto &type)
{
    type.mutable_tensor_type()->set_elem_type(ConvertDataTypeCNTKToTensorProto(dataType));
}
std::string CNTKToONNXHelper::ToOPName(const FunctionPtr& src)
{
auto lookup = Operators::CntkToONNXLookup();
@ -2847,6 +2857,12 @@ void CNTKToONNXHelper::CopyAttributes(const FunctionPtr& src, LotusIR::Node* nod
node->AddAttribute("min", minValue);
node->AddAttribute("max", maxValue);
}
else if (src->OpName() == L"Cast")
{
// CNTK stores the cast target type as an int attribute ("newDataType").
// Translate it to the ONNX TensorProto_DataType enum value, which ONNX
// Cast carries in its attribute (mapped via attributesMap to "to").
DataType newDataType = static_cast<DataType>(src->Attributes()[L"newDataType"].Value<int>());
int64_t to = static_cast<int64_t>(ConvertDataTypeCNTKToTensorProto(newDataType));
node->AddAttribute(attributesMap[L"newDataType"], to);
}
if (src->OpName() == L"BatchNormalization")
{
auto spatial = (int64_t)((bool)src->Attributes()[L"spatial"].Value<bool>() ? 1 : 0);

View file

@ -1911,17 +1911,27 @@ std::pair<std::vector<size_t>, std::vector<size_t>> ONNXToCNTKHelper::AdjustONNX
return SplitAndReverseVec(pads);
}
// Map an ONNX TensorProto element type to the corresponding CNTK DataType
// (inverse of ConvertDataTypeCNTKToTensorProto). Raises NOT_IMPLEMENTED for
// element types CNTK has no equivalent for.
CNTK::DataType ConvertDataTypeTensorProtoToCNTK(TensorProto_DataType newDataType)
{
switch (newDataType)
{
case TensorProto_DataType::TensorProto_DataType_FLOAT:
return CNTK::DataType::Float;
case TensorProto_DataType::TensorProto_DataType_DOUBLE:
return CNTK::DataType::Double;
case TensorProto_DataType::TensorProto_DataType_FLOAT16:
return CNTK::DataType::Float16;
default:
NOT_IMPLEMENTED;
}
}
FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector<Variable> &inputs, const Graph *graph)
{
string onnxOpName = node->OpType();
// The old float-only Cast pass-through (which simply returned the input's
// owner) is gone; Cast is now handled by its own branch further down.
if (onnxOpName == "LSTM")
{
const string direction = GetNamedAttributeAsString(node, "direction");
std::vector<float> activation_alpha = GetNamedAttributeAsFloatVec(node, "activation_alpha", std::vector<float>());
@ -2721,6 +2731,13 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector
FunctionPtr cntkFunction = Acos(inputs[0], ToWString(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Cast")
{
// ONNX Cast carries the target element type in its "to" attribute as an
// int64 TensorProto_DataType value; translate it to a CNTK DataType and
// create the corresponding CNTK cast op.
TensorProto_DataType newDataType = static_cast<TensorProto_DataType>(GetNamedAttributeAsInt64(node, "to"));
DataType cntkNewDataType = ConvertDataTypeTensorProtoToCNTK(newDataType);
FunctionPtr cntkFunction = Cast(inputs[0], cntkNewDataType, ToWString(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Tan")
{
FunctionPtr cntkFunction = Tan(inputs[0], ToWString(node->Name()));

View file

@ -352,7 +352,10 @@ namespace ONNX
} } },
// From tensor
// { L"", "Cast" },
// CNTK Cast maps to ONNX Cast; CNTK's "newDataType" attribute becomes
// ONNX's "to" attribute.
{ L"Cast", { {
{ L"Cast", "Cast" },
{ L"newDataType", "to" },
} } },
{ L"Splice", { {
{ L"Splice", "Concat" },
{ L"axis", "axis" },

View file

@ -30,7 +30,7 @@ inline Status FileOpenRd(const std::wstring& path, /*out*/ int* p_fd) {
}
inline Status FileOpenWr(const std::wstring& path, /*out*/ int* p_fd) {
// Truncate any existing file so stale trailing bytes are not left behind when
// the new contents are shorter. Use _O_TRUNC for consistency with the other
// MSVC-CRT _O_* flags in this call.
_wsopen_s(p_fd, path.c_str(), _O_CREAT | _O_TRUNC | _O_SEQUENTIAL | _O_BINARY | _O_WRONLY, _SH_DENYWR, _S_IREAD | _S_IWRITE);
if (0 > *p_fd) {
return Status(SYSTEM, errno);
}
@ -52,7 +52,7 @@ inline Status FileOpenRd(const std::string& path, /*out*/ int* p_fd) {
inline Status FileOpenWr(const std::string& path, /*out*/ int* p_fd) {
#ifdef _WIN32
// Truncate any existing file (mirrors the POSIX O_TRUNC branch below). Use
// _O_TRUNC for consistency with the other MSVC-CRT _O_* flags in this call.
_sopen_s(p_fd, path.c_str(), _O_CREAT | _O_TRUNC | _O_SEQUENTIAL | _O_BINARY | _O_WRONLY, _SH_DENYWR, _S_IREAD | _S_IWRITE);
#else
*p_fd = open(path.c_str(), O_WRONLY | O_CREAT | O_TRUNC, 0644);
#endif

View file

@ -234,6 +234,18 @@ def test_BatchNormalization(tmpdir, dtype):
verify_one_input(op_node, t, tmpdir, 'BatchNormalization')
#Cast
Cast_Type_Config = (np.float64, np.float32, np.float16)

@pytest.mark.parametrize("from_type", Cast_Type_Config)
@pytest.mark.parametrize("to_type", Cast_Type_Config)
def test_Cast(tmpdir, from_type, to_type):
    # Exercise every (source, target) dtype pair for the cast op and check
    # the exported ONNX model against the CNTK model via verify_one_input.
    dims = (3, 10, 15)
    case_name = "cast_{}_to_{}".format(from_type.__name__, to_type.__name__)
    features = C.input_variable(dims, dtype = from_type, name='features')
    cast_model = C.cast(features, dtype=to_type)
    sample = np.random.rand(*dims).astype(from_type)
    verify_one_input(cast_model, sample, tmpdir, case_name)
# Ceil
@pytest.mark.parametrize("dtype", DType_Config)
def test_Ceil(tmpdir, dtype):