enable loading onnx from buffer

This commit is contained in:
Yang Chen 2019-07-16 10:08:27 -07:00
Parent 9fe872b413
Commit f310127f72
4 changed files with 48 additions and 9 deletions

View File

@ -3613,7 +3613,8 @@ namespace CNTK
/// Load a Function from a memory buffer
///
CNTK_API static FunctionPtr Load(const char* buffer, size_t length,
const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());
const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice(),
ModelFormat format = ModelFormat::CNTKv2);
///
/// Load a Function from an istream. The legacy V1 model is not supported.

View File

@ -531,7 +531,7 @@ namespace CNTK
return nullptr;
}
/*static*/ FunctionPtr Function::Load(const char *buffer, size_t length, const DeviceDescriptor& computeDevice)
// Loads a Function from an in-memory model buffer, dispatching on `format`.
// CNTKv2 buffers are wrapped in an istream and delegated to the stream-based
// Load overload; ONNX buffers are handed to the ONNX importer.
// NOTE(review): this span is a rendered diff — lines of the removed
// (pre-change) implementation are still interleaved below (the old
// IsLegacyModel/else sequence and a duplicated return); they are preserved
// verbatim here and should be dropped when the diff is applied.
/*static*/ FunctionPtr Function::Load(const char *buffer, size_t length, const DeviceDescriptor& computeDevice, ModelFormat format)
{
// Precondition: a non-null buffer with a positive length.
if ((buffer == nullptr) || (length <= 0))
InvalidArgument("The model buffer should not be null and its length should be greater than 0");
@ -545,15 +545,32 @@
}
};
if (Internal::IsLegacyModel(buffer, length))
InvalidArgument("Loading a legacy model from byte array is not supported.");
else
// Dispatch on the requested serialization format.
switch (format)
{
modelStreamBuffer buf(buffer, length);
std::istream modelStream(&buf);
case ModelFormat::CNTKv2:
{
// Legacy V1 models cannot be deserialized from a byte array.
if (Internal::IsLegacyModel(buffer, length)) {
InvalidArgument("Loading a legacy model from byte array is not supported.");
}
else
{
// Wrap the raw bytes in a streambuf so the istream-based
// Load overload can deserialize them without copying.
modelStreamBuffer buf(buffer, length);
std::istream modelStream(&buf);
return Load(modelStream, computeDevice);
return Load(modelStream, computeDevice);
}
break;
}
case ModelFormat::ONNX:
// Hand the raw bytes straight to the ONNX importer.
return ONNXFormat::Load(static_cast<const void*>(buffer), length, computeDevice);
break;
default:
InvalidArgument("unsupported ModelFormat.");
}
// Reached only when a branch above falls through without returning
// (e.g. the legacy-model path, assuming InvalidArgument throws).
return nullptr;
}
/*static*/ FunctionPtr Function::Load(std::istream& inputStream, const DeviceDescriptor& computeDevice)

View File

@ -121,3 +121,23 @@ FunctionPtr ONNXFormat::Load(const std::wstring& filepath, const DeviceDescripto
FunctionPtr cntkFunction = ONNXToCNTK::CreateGraph(&model->MainGraph(), computeDevice);
return cntkFunction;
}
///
/// Loads an ONNX model held in a memory buffer and converts it into a CNTK
/// Function graph.
///
/// @param model_data      Pointer to the serialized ONNX ModelProto bytes.
/// @param model_data_len  Size of the buffer in bytes. Declared as int to
///                        match protobuf's ParseFromArray signature; buffers
///                        larger than INT_MAX are not supported.
/// @param computeDevice   Device on which the resulting Function is placed.
/// @return The imported Function graph.
FunctionPtr ONNXFormat::Load(const void* model_data, int model_data_len, const DeviceDescriptor& computeDevice)
{
    // Validate the caller-supplied buffer up front, consistent with
    // Function::Load(const char*, size_t, ...), instead of letting protobuf
    // fail on a null/empty input with a less specific error.
    if ((model_data == nullptr) || (model_data_len <= 0))
        InvalidArgument("The model buffer should not be null and its length should be greater than 0");

    // One-time LotusIR/op-schema setup (internally guarded by a once_flag).
    InitializeLotusIR();

    // Deserialize the raw bytes into an ONNX ModelProto.
    onnx::ModelProto model_proto;
    if (!model_proto.ParseFromArray(model_data, model_data_len))
        LogicError("protobuf failed to parse model");

    // Build the in-memory ONNX graph representation.
    std::shared_ptr<onnxruntime::Model> model;
    const onnxruntime::common::Status loadStatus = onnxruntime::Model::Load(model_proto, model);
    if (!loadStatus.IsOK())
        LogicError("Failed to load model: '%s'", loadStatus.ErrorMessage().c_str());

    // Translate the ONNX graph into a CNTK Function graph.
    return ONNXToCNTK::CreateGraph(&model->MainGraph(), computeDevice);
}

View File

@ -16,8 +16,9 @@ namespace CNTK
public:
/// Serializes `src` as an ONNX model file at `filepath`.
static void Save(const FunctionPtr& src, const std::wstring& filepath);
/// Loads an ONNX model file and converts it to a CNTK Function graph.
static FunctionPtr Load(const std::wstring& filepath, const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());
/// Loads an ONNX model held in memory (`model_data_len` bytes at
/// `model_data`) and converts it to a CNTK Function graph.
/// NOTE(review): the length is `int` while Function::Load takes size_t —
/// callers narrow; buffers over INT_MAX are unsupported. Confirm intent.
static FunctionPtr Load(const void* model_data, int model_data_len, const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());
private:
/// One-time LotusIR/op-schema initialization performed before any Load;
/// presumably run via std::call_once on op_schema_initializer_flag_.
static void InitializeLotusIR();
// Flag ensuring the one-time schema initialization runs exactly once.
static std::once_flag op_schema_initializer_flag_;
}
}