Added an API to load model from memory

This commit is contained in:
Yang Chen 2019-06-19 17:33:07 -07:00
Родитель 238eb2d055
Коммит 9f7dd2b63e
4 изменённых файлов: 42 добавлений и 0 удалений

Просмотреть файл

@ -109,6 +109,22 @@ CNTK_API CNTK_StatusCode CNTK_LoadModel(
/*[in]*/ const CNTK_DeviceDescriptor* device,
/*[out]*/ CNTK_ModelHandle* model);
//
// Loads a model from the specified buffer and returns an opaque handle to the model
// that should be passed to further operations.
//
// Parameters:
// modelData [in]: a buffer that holds the CNTK model
// modelDataLen [in]: the length of the buffer
// device [in]: device descriptor.
// model [out]: the resulting loaded model
//
CNTK_StatusCode CNTKModelCompiler_LoadModel_FromArray(
/*[in]*/ const void* modelData,
/*[in]*/ int modelDataLen,
/*[in]*/ const CNTK_DeviceDescriptor* device,
/*[out]*/ CNTK_ModelHandle* model);
enum CNTK_ParameterCloningMethod
{
///

Просмотреть файл

@ -152,6 +152,8 @@ namespace CNTK
public:
CNTKEvaluatorWrapper(const char* modelFilePath, const CNTK_DeviceDescriptor* device);
CNTKEvaluatorWrapper(const char* modelFilePath, DeviceDescriptor device);
CNTKEvaluatorWrapper(const void* modelData, int modelDataLen, const CNTK_DeviceDescriptor* device);
CNTKEvaluatorWrapper(const void* modelData, int modelDataLen, DeviceDescriptor device);
CNTKEvaluatorWrapper(FunctionPtr model, DeviceDescriptor device);
void GetModelArgumentsInfo(CNTK_Variable** inputs, uint32_t* numInputs) override;

Просмотреть файл

@ -104,6 +104,22 @@ CNTK_StatusCode CNTK_LoadModel(const char* modelFilePath, const CNTK_DeviceDescr
return ExceptionCatcher::Call([&]() { *handle = new CNTKEvaluatorWrapper(modelFilePath, device); });
}
CNTK_StatusCode CNTKModelCompiler_LoadModel_FromArray(const void* modelData, int modelDataLen,
const CNTK_DeviceDescriptor* device, CNTK_ModelHandle* handle)
{
if (!handle)
return StatusCode(CNTK_ERROR_NULL_POINTER, "'handle' parameter is not allowed to be null");
if (!modelData)
return StatusCode(CNTK_ERROR_NULL_POINTER, "'modelData' parameter is not allowed to be null");
if (modelDataLen <= 0)
return StatusCode(CNTK_ERROR_INVALID_INPUT, "'modelDataLen' parameter must be greater than zero");
*handle = nullptr;
return ExceptionCatcher::Call([&]() { *handle = new CNTKEvaluatorWrapper(modelData, modelDataLen, device); });
}
CNTK_StatusCode CNTK_CloneModel(CNTK_ModelHandle model, CNTK_ParameterCloningMethod method, bool flatten, CNTK_ModelHandle* cloned)
{
if (model == CNTK_INVALID_MODEL_HANDLE)

Просмотреть файл

@ -32,6 +32,14 @@ namespace CNTK
CNTKEvaluatorWrapper(modelFilePath, GetDeviceDescriptor(device))
{}
CNTKEvaluatorWrapper::CNTKEvaluatorWrapper(const void* modelData, int modelDataLen, DeviceDescriptor device) :
CNTKEvaluatorWrapper(Function::Load(static_cast<const char*>(modelData), modelDataLen, device), device)
{}
CNTKEvaluatorWrapper::CNTKEvaluatorWrapper(const void* modelData, int modelDataLen, const CNTK_DeviceDescriptor* device) :
CNTKEvaluatorWrapper(modelData, modelDataLen, GetDeviceDescriptor(device))
{}
void CNTKEvaluatorWrapper::GetModelArgumentsInfo(CNTK_Variable** inputs, uint32_t* numInputs)
{
assert(inputs != nullptr);