enable loading onnx from buffer
This commit is contained in:
Родитель
9fe872b413
Коммит
f310127f72
|
@ -3613,7 +3613,8 @@ namespace CNTK
|
|||
/// Load a Function from a memory buffer
|
||||
///
|
||||
CNTK_API static FunctionPtr Load(const char* buffer, size_t length,
|
||||
const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());
|
||||
const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice(),
|
||||
ModelFormat format = ModelFormat::CNTKv2);
|
||||
|
||||
///
|
||||
/// Load a Function from an istream. The legacy V1 model is not supported.
|
||||
|
|
|
@ -531,7 +531,7 @@ namespace CNTK
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
/*static*/ FunctionPtr Function::Load(const char *buffer, size_t length, const DeviceDescriptor& computeDevice)
|
||||
/*static*/ FunctionPtr Function::Load(const char *buffer, size_t length, const DeviceDescriptor& computeDevice, ModelFormat format)
|
||||
{
|
||||
if ((buffer == nullptr) || (length <= 0))
|
||||
InvalidArgument("The model buffer should not be null and its length should be greater than 0");
|
||||
|
@ -545,15 +545,32 @@ namespace CNTK
|
|||
}
|
||||
};
|
||||
|
||||
if (Internal::IsLegacyModel(buffer, length))
|
||||
InvalidArgument("Loading a legacy model from byte array is not supported.");
|
||||
else
|
||||
switch (format)
|
||||
{
|
||||
modelStreamBuffer buf(buffer, length);
|
||||
std::istream modelStream(&buf);
|
||||
case ModelFormat::CNTKv2:
|
||||
{
|
||||
if (Internal::IsLegacyModel(buffer, length)) {
|
||||
InvalidArgument("Loading a legacy model from byte array is not supported.");
|
||||
}
|
||||
else
|
||||
{
|
||||
modelStreamBuffer buf(buffer, length);
|
||||
std::istream modelStream(&buf);
|
||||
|
||||
return Load(modelStream, computeDevice);
|
||||
return Load(modelStream, computeDevice);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case ModelFormat::ONNX:
|
||||
return ONNXFormat::Load(static_cast<const void*>(buffer), length, computeDevice);
|
||||
break;
|
||||
|
||||
default:
|
||||
InvalidArgument("unsupported ModelFormat.");
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/*static*/ FunctionPtr Function::Load(std::istream& inputStream, const DeviceDescriptor& computeDevice)
|
||||
|
|
|
@ -121,3 +121,23 @@ FunctionPtr ONNXFormat::Load(const std::wstring& filepath, const DeviceDescripto
|
|||
FunctionPtr cntkFunction = ONNXToCNTK::CreateGraph(&model->MainGraph(), computeDevice);
|
||||
return cntkFunction;
|
||||
}
|
||||
|
||||
FunctionPtr ONNXFormat::Load(const void* model_data, int model_data_len, const DeviceDescriptor& computeDevice)
|
||||
{
|
||||
InitializeLotusIR();
|
||||
|
||||
onnx::ModelProto model_proto;
|
||||
const bool result = model_proto.ParseFromArray(model_data, model_data_len);
|
||||
if (!result) {
|
||||
LogicError("protobuf failed to parse model");
|
||||
}
|
||||
|
||||
std::shared_ptr<onnxruntime::Model> model;
|
||||
onnxruntime::common::Status loadStatus = onnxruntime::Model::Load(model_proto, model);
|
||||
|
||||
if (!loadStatus.IsOK())
|
||||
LogicError("Failed to load model: '%s'", loadStatus.ErrorMessage().c_str());
|
||||
|
||||
FunctionPtr cntkFunction = ONNXToCNTK::CreateGraph(&model->MainGraph(), computeDevice);
|
||||
return cntkFunction;
|
||||
}
|
||||
|
|
|
@ -16,8 +16,9 @@ namespace CNTK
|
|||
public:
|
||||
static void Save(const FunctionPtr& src, const std::wstring& filepath);
|
||||
static FunctionPtr Load(const std::wstring& filepath, const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());
|
||||
static FunctionPtr Load(const void* model_data, int model_data_len, const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());
|
||||
private:
|
||||
static void InitializeLotusIR();
|
||||
static std::once_flag op_schema_initializer_flag_;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче