Added an API to load model from memory
This commit is contained in:
Родитель
238eb2d055
Коммит
9f7dd2b63e
|
@ -109,6 +109,22 @@ CNTK_API CNTK_StatusCode CNTK_LoadModel(
|
|||
/*[in]*/ const CNTK_DeviceDescriptor* device,
|
||||
/*[out]*/ CNTK_ModelHandle* model);
|
||||
|
||||
//
|
||||
// Loads a model from the specified buffer and returns an opaque handle to the model
|
||||
// that should be passed to further operations.
|
||||
//
|
||||
// Parameters:
|
||||
// modelData [in]: a buffer that holds the CNTK model
|
||||
// modelDataLen [in]: the length of the buffer
|
||||
// device [in]: device descriptor.
|
||||
// model [out]: the resulting loaded model
|
||||
//
|
||||
CNTK_StatusCode CNTKModelCompiler_LoadModel_FromArray(
|
||||
/*[in]*/ const void* modelData,
|
||||
/*[in]*/ int modelDataLen,
|
||||
/*[in]*/ const CNTK_DeviceDescriptor* device,
|
||||
/*[out]*/ CNTK_ModelHandle* model);
|
||||
|
||||
enum CNTK_ParameterCloningMethod
|
||||
{
|
||||
///
|
||||
|
|
|
@ -152,6 +152,8 @@ namespace CNTK
|
|||
public:
|
||||
CNTKEvaluatorWrapper(const char* modelFilePath, const CNTK_DeviceDescriptor* device);
|
||||
CNTKEvaluatorWrapper(const char* modelFilePath, DeviceDescriptor device);
|
||||
CNTKEvaluatorWrapper(const void* modelData, int modelDataLen, const CNTK_DeviceDescriptor* device);
|
||||
CNTKEvaluatorWrapper(const void* modelData, int modelDataLen, DeviceDescriptor device);
|
||||
CNTKEvaluatorWrapper(FunctionPtr model, DeviceDescriptor device);
|
||||
|
||||
void GetModelArgumentsInfo(CNTK_Variable** inputs, uint32_t* numInputs) override;
|
||||
|
|
|
@ -104,6 +104,22 @@ CNTK_StatusCode CNTK_LoadModel(const char* modelFilePath, const CNTK_DeviceDescr
|
|||
return ExceptionCatcher::Call([&]() { *handle = new CNTKEvaluatorWrapper(modelFilePath, device); });
|
||||
}
|
||||
|
||||
CNTK_StatusCode CNTKModelCompiler_LoadModel_FromArray(const void* modelData, int modelDataLen,
|
||||
const CNTK_DeviceDescriptor* device, CNTK_ModelHandle* handle)
|
||||
{
|
||||
if (!handle)
|
||||
return StatusCode(CNTK_ERROR_NULL_POINTER, "'handle' parameter is not allowed to be null");
|
||||
|
||||
if (!modelData)
|
||||
return StatusCode(CNTK_ERROR_NULL_POINTER, "'modelData' parameter is not allowed to be null");
|
||||
|
||||
if (modelDataLen <= 0)
|
||||
return StatusCode(CNTK_ERROR_INVALID_INPUT, "'modelDataLen' parameter must be greater than zero");
|
||||
|
||||
*handle = nullptr;
|
||||
return ExceptionCatcher::Call([&]() { *handle = new CNTKEvaluatorWrapper(modelData, modelDataLen, device); });
|
||||
}
|
||||
|
||||
CNTK_StatusCode CNTK_CloneModel(CNTK_ModelHandle model, CNTK_ParameterCloningMethod method, bool flatten, CNTK_ModelHandle* cloned)
|
||||
{
|
||||
if (model == CNTK_INVALID_MODEL_HANDLE)
|
||||
|
|
|
@ -32,6 +32,14 @@ namespace CNTK
|
|||
CNTKEvaluatorWrapper(modelFilePath, GetDeviceDescriptor(device))
|
||||
{}
|
||||
|
||||
CNTKEvaluatorWrapper::CNTKEvaluatorWrapper(const void* modelData, int modelDataLen, DeviceDescriptor device) :
|
||||
CNTKEvaluatorWrapper(Function::Load(static_cast<const char*>(modelData), modelDataLen, device), device)
|
||||
{}
|
||||
|
||||
CNTKEvaluatorWrapper::CNTKEvaluatorWrapper(const void* modelData, int modelDataLen, const CNTK_DeviceDescriptor* device) :
|
||||
CNTKEvaluatorWrapper(modelData, modelDataLen, GetDeviceDescriptor(device))
|
||||
{}
|
||||
|
||||
void CNTKEvaluatorWrapper::GetModelArgumentsInfo(CNTK_Variable** inputs, uint32_t* numInputs)
|
||||
{
|
||||
assert(inputs != nullptr);
|
||||
|
|
Загрузка…
Ссылка в новой задаче