Improve UDF serialization
* Add a base class for User-Defined Functions * Add support for native udf serialization * Change API to accept std::function callbacks
This commit is contained in:
Родитель
6c9d42f762
Коммит
5bc3661988
|
@ -16,41 +16,71 @@ void UserTimesFunctionExample()
|
|||
auto x = InputVariable(NDShape({ inDim }), DataType::Float, { Axis::DefaultBatchAxis() });
|
||||
auto userDefinedTimes = UserTimesFunction::Create(W, x, L"UserDefinedTimes");
|
||||
|
||||
size_t batchSize = 3;
|
||||
std::vector<float> inputData(inDim * batchSize);
|
||||
for (size_t i = 0; i < inputData.size(); ++i)
|
||||
inputData[i] = (float)rand() / RAND_MAX;
|
||||
|
||||
auto inputDataValue = Value::CreateBatch(x.Shape(), inputData, device);
|
||||
auto compareWithBuiltInTimes = [device, outDim, inDim](FunctionPtr times) {
|
||||
size_t batchSize = 3;
|
||||
std::vector<float> inputData(inDim * batchSize);
|
||||
for (size_t i = 0; i < inputData.size(); ++i)
|
||||
inputData[i] = (float)rand() / RAND_MAX;
|
||||
|
||||
std::vector<float> rootGradientData(outDim * batchSize, 1);
|
||||
auto rootGradientValue = Value::CreateBatch(userDefinedTimes->Output().Shape(), rootGradientData, device);
|
||||
auto input = times->Arguments()[0];
|
||||
auto inputDataValue = Value::CreateBatch(input.Shape(), inputData, device);
|
||||
|
||||
std::unordered_map<Variable, ValuePtr> outputValues = { { userDefinedTimes->Output(), nullptr } };
|
||||
auto backPropState = userDefinedTimes->Forward({ { x, inputDataValue } }, outputValues, device, { userDefinedTimes->Output() });
|
||||
std::vector<float> rootGradientData(outDim * batchSize, 1);
|
||||
auto rootGradientValue = Value::CreateBatch(times->Output().Shape(), rootGradientData, device);
|
||||
|
||||
std::unordered_map<Variable, ValuePtr> inputGradientValues = { { W, nullptr } };
|
||||
userDefinedTimes->Backward(backPropState, { { userDefinedTimes->Output(), rootGradientValue } }, inputGradientValues);
|
||||
auto userDefinedTimesOutputValue = outputValues[userDefinedTimes->Output()];
|
||||
auto userDefinedTimesInputGradientValue = inputGradientValues[W];
|
||||
std::unordered_map<Variable, ValuePtr> outputValues = { { times->Output(), nullptr } };
|
||||
auto backPropState = times->Forward({ { input, inputDataValue } }, outputValues, device, { times->Output() });
|
||||
|
||||
// Compare against the CNTK built-in implementation
|
||||
auto builtInTimes = Times(W, x, L"BuiltInTimes");
|
||||
outputValues = { { builtInTimes->Output(), nullptr } };
|
||||
backPropState = builtInTimes->Forward({ { x, inputDataValue } }, outputValues, device, { builtInTimes->Output() });
|
||||
inputGradientValues = { { W, nullptr } };
|
||||
builtInTimes->Backward(backPropState, { { builtInTimes->Output(), rootGradientValue } }, inputGradientValues);
|
||||
auto builtInTimesOutputValue = outputValues[builtInTimes->Output()];
|
||||
auto builtInTimesInputGradientValue = inputGradientValues[W];
|
||||
|
||||
const double relativeTolerance = 0.001f;
|
||||
const double absoluteTolerance = 0.000001f;
|
||||
auto parameter = times->Parameters()[0];
|
||||
|
||||
if (!Internal::AreEqual(*userDefinedTimesOutputValue, *builtInTimesOutputValue, relativeTolerance, absoluteTolerance))
|
||||
std::runtime_error("UserTimesOp's Forward result does not match built-in result");
|
||||
std::unordered_map<Variable, ValuePtr> inputGradientValues = { { parameter, nullptr } };
|
||||
times->Backward(backPropState, { { times->Output(), rootGradientValue } }, inputGradientValues);
|
||||
auto userDefinedTimesOutputValue = outputValues[times->Output()];
|
||||
auto userDefinedTimesInputGradientValue = inputGradientValues[parameter];
|
||||
|
||||
if (!Internal::AreEqual(*userDefinedTimesInputGradientValue, *builtInTimesInputGradientValue, relativeTolerance, absoluteTolerance))
|
||||
std::runtime_error("UserTimesOp's Forward result does not match built-in result");
|
||||
// Compare against the CNTK built-in implementation
|
||||
auto builtInTimes = Times(parameter, input, L"BuiltInTimes");
|
||||
outputValues = { { builtInTimes->Output(), nullptr } };
|
||||
backPropState = builtInTimes->Forward({ { input, inputDataValue } }, outputValues, device, { builtInTimes->Output() });
|
||||
inputGradientValues = { { parameter, nullptr } };
|
||||
builtInTimes->Backward(backPropState, { { builtInTimes->Output(), rootGradientValue } }, inputGradientValues);
|
||||
auto builtInTimesOutputValue = outputValues[builtInTimes->Output()];
|
||||
auto builtInTimesInputGradientValue = inputGradientValues[parameter];
|
||||
|
||||
const double relativeTolerance = 0.001f;
|
||||
const double absoluteTolerance = 0.000001f;
|
||||
|
||||
if (!Internal::AreEqual(*userDefinedTimesOutputValue, *builtInTimesOutputValue, relativeTolerance, absoluteTolerance))
|
||||
std::runtime_error("UserTimesOp's Forward result does not match built-in result");
|
||||
|
||||
if (!Internal::AreEqual(*userDefinedTimesInputGradientValue, *builtInTimesInputGradientValue, relativeTolerance, absoluteTolerance))
|
||||
std::runtime_error("UserTimesOp's Forward result does not match built-in result");
|
||||
|
||||
};
|
||||
|
||||
compareWithBuiltInTimes(userDefinedTimes);
|
||||
|
||||
auto version = std::string(CNTK_COMPONENT_VERSION);
|
||||
std::wstring wversion(version.begin(), version.end());
|
||||
Function::RegisterNativeUserFunction(L"NativeUserTimesOp", L"Cntk.ExtensibilityExamples-" + wversion, L"CreateUserTimesFunction");
|
||||
|
||||
userDefinedTimes->Save(L"UserTimesFunctionExample.model");
|
||||
|
||||
auto userDefinedTimes_reloaded_1 = Function::Load(L"UserTimesFunctionExample.model", device);
|
||||
|
||||
compareWithBuiltInTimes(userDefinedTimes_reloaded_1);
|
||||
|
||||
Function::RegisterUDFDeserializeCallback(L"NativeUserTimesOp", [](const std::vector<Variable>& inputs,
|
||||
const std::wstring& name,
|
||||
const Dictionary& state) {
|
||||
return UserTimesFunction::Create(inputs[0], inputs[1], state, name);
|
||||
});
|
||||
|
||||
auto userDefinedTimes_reloaded_2 = Function::Load(L"UserTimesFunctionExample.model", device);
|
||||
|
||||
compareWithBuiltInTimes(userDefinedTimes_reloaded_2);
|
||||
}
|
||||
#pragma warning(pop)
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ public:
|
|||
}
|
||||
|
||||
UserTimesFunction(const Variable& leftOperand, const Variable& rightOperand, const Dictionary& attributes, const std::wstring& name)
|
||||
: Function({ leftOperand, rightOperand }, Dictionary(attributes), name)
|
||||
: Function({ leftOperand, rightOperand }, attributes, name)
|
||||
{}
|
||||
|
||||
private:
|
||||
|
@ -116,11 +116,10 @@ private:
|
|||
|
||||
const std::wstring& OpName() const override
|
||||
{
|
||||
static const std::wstring opName = L"UserTimesOp";
|
||||
static const std::wstring opName = L"NativeUserTimesOp";
|
||||
return opName;
|
||||
}
|
||||
|
||||
Dictionary Serialize() const override { NOT_IMPLEMENTED; }
|
||||
size_t CurrentVersion() const override { NOT_IMPLEMENTED; }
|
||||
|
||||
void InferOutputs(std::vector<Variable>& outputs) override
|
||||
|
|
1
Makefile
1
Makefile
|
@ -469,6 +469,7 @@ CNTKLIBRARY_COMMON_SRC =\
|
|||
$(SOURCEDIR)/CNTKv2LibraryDll/Function.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/PrimitiveFunction.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/CompositeFunction.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/UserDefinedFunction.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/NDArrayView.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/NDMask.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/Trainer.cpp \
|
||||
|
|
|
@ -2885,6 +2885,8 @@ namespace CNTK
|
|||
const std::wstring& /*name*/,
|
||||
const Dictionary& /*dictionary*/)> UDFDeserializeCallback;
|
||||
|
||||
typedef std::shared_ptr<UDFDeserializeCallback> UDFDeserializeCallbackPtr;
|
||||
|
||||
static auto NoOp = [] (const std::vector<Variable>&, const std::wstring&, const Dictionary&)
|
||||
{
|
||||
return nullptr;
|
||||
|
@ -2907,6 +2909,7 @@ namespace CNTK
|
|||
friend class Trainer;
|
||||
|
||||
friend Variable GetCorrespondingOutputVariableFromClone(const Variable&, const FunctionPtr&, const FunctionPtr&);
|
||||
friend bool Internal::IsNativeUserFunctionRegistered(const std::wstring& uniqueOpName);
|
||||
|
||||
public:
|
||||
|
||||
|
@ -2975,6 +2978,20 @@ namespace CNTK
|
|||
///
|
||||
CNTK_API virtual void InferOutputs(std::vector<Variable>& outputs) = 0;
|
||||
|
||||
///
|
||||
/// Returns the name of the module (dll/so) containing this function. For native functions registered through
|
||||
/// a call to 'RegisterNativeUserFunction', unless overridden, this method return the value of the 'moduleName'
|
||||
/// argument.
|
||||
///
|
||||
CNTK_API virtual std::wstring ModuleName() const;
|
||||
|
||||
///
|
||||
/// Returns the name of the method which should be invoked to deserialize this function. For native functions
|
||||
/// registered through a call to 'RegisterNativeUserFunction', unless overridden, this method return the value
|
||||
/// of the 'factoryMethodName' argument. If overridden, it must have the same signature as the factory method.
|
||||
///
|
||||
CNTK_API virtual std::wstring DeserializeMethodName() const;
|
||||
|
||||
public:
|
||||
|
||||
// Optional overrides
|
||||
|
@ -2997,7 +3014,7 @@ namespace CNTK
|
|||
///
|
||||
/// Generates a dictionary that captures the state of the Function graph underlying this Function.
|
||||
///
|
||||
CNTK_API virtual Dictionary Serialize() const override { return Dictionary(); }
|
||||
CNTK_API virtual Dictionary Serialize() const override { return Attributes(); }
|
||||
|
||||
///
|
||||
/// Creates a clone of this Function instance, using the specified 'inputs' that are inputs of the clone to be constructed.
|
||||
|
@ -3052,8 +3069,7 @@ namespace CNTK
|
|||
/// if deserializer was omitted). If there are no user defined functions in the model, deserializer is ignored.
|
||||
///
|
||||
CNTK_API static FunctionPtr Deserialize(const Dictionary& dictionary,
|
||||
const ::CNTK::DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice(),
|
||||
const UDFDeserializeCallback& deserializer = NoOp);
|
||||
const ::CNTK::DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice());
|
||||
|
||||
public:
|
||||
///
|
||||
|
@ -3250,44 +3266,19 @@ namespace CNTK
|
|||
/// Load a Function from a model file
|
||||
///
|
||||
CNTK_API static FunctionPtr Load(const std::wstring& filepath,
|
||||
const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice(),
|
||||
const UDFDeserializeCallback& deserializer = NoOp);
|
||||
const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());
|
||||
|
||||
///
|
||||
/// Load a Function from a memory buffer
|
||||
///
|
||||
CNTK_API static FunctionPtr Load(const char* buffer, size_t length,
|
||||
const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice(),
|
||||
const UDFDeserializeCallback& deserializer = NoOp);
|
||||
const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());
|
||||
|
||||
///
|
||||
/// Load a Function from an istream. The legacy V1 model is not supported.
|
||||
///
|
||||
CNTK_API static FunctionPtr Load(std::istream& inputStream,
|
||||
const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice(),
|
||||
const UDFDeserializeCallback& deserializer = NoOp);
|
||||
|
||||
#ifdef SWIG // for Python interop (adds callback wrapper)
|
||||
static CNTK::FunctionPtr Load(const std::wstring& filepath,
|
||||
const CNTK::DeviceDescriptor& computeDevice,
|
||||
const CNTK::Internal::UDFDeserializeCallbackWrapper& wrapper)
|
||||
{
|
||||
using namespace std::placeholders;
|
||||
UDFDeserializeCallback callback = std::bind(&CNTK::Internal::UDFDeserializeCallbackWrapper::operator(),
|
||||
&wrapper, _1, _2, _3);
|
||||
return CNTK::Function::Load(filepath, computeDevice, callback);
|
||||
}
|
||||
|
||||
static CNTK::FunctionPtr Load(const char* buffer, size_t length,
|
||||
const CNTK::DeviceDescriptor& computeDevice,
|
||||
const CNTK::Internal::UDFDeserializeCallbackWrapper& wrapper)
|
||||
{
|
||||
using namespace std::placeholders;
|
||||
UDFDeserializeCallback callback = std::bind(&CNTK::Internal::UDFDeserializeCallbackWrapper::operator(),
|
||||
&wrapper, _1, _2, _3);
|
||||
return CNTK::Function::Load(buffer, length, computeDevice, callback);
|
||||
}
|
||||
#endif
|
||||
const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());
|
||||
|
||||
///
|
||||
/// Prints the entire graph underlying this Function to stderr
|
||||
|
@ -3317,6 +3308,18 @@ namespace CNTK
|
|||
///
|
||||
CNTK_API static FunctionPtr NativeUserFunction(const std::wstring& opName, const std::vector<Variable>& operands, const Dictionary& functionConfig, const std::wstring& userFunctionInstanceName = L"");
|
||||
|
||||
///
|
||||
/// Register a callback function to be invoked when deserializing a user-defined Function with the corresponding op name.
|
||||
/// When loading a model, CNTK will try to automatically reconstruct user-defined Functions (for native functions, CNTK will
|
||||
/// invoke the same factory method, the Function op name was registered with). This method allows to override
|
||||
/// default user-defined Function deserialization behavior by specifying an op name and the corresponding callback that should be invoked
|
||||
/// to inflate the Function object.
|
||||
///
|
||||
CNTK_API static void RegisterUDFDeserializeCallback(const std::wstring& uniqueOpName, const UDFDeserializeCallback& deserializer);
|
||||
|
||||
static UDFDeserializeCallbackPtr GetUDFDeserializeCallback(const std::wstring& uniqueOpName);
|
||||
|
||||
|
||||
protected:
|
||||
static bool IsArgument(const Variable& var)
|
||||
{
|
||||
|
@ -3324,9 +3327,10 @@ namespace CNTK
|
|||
}
|
||||
|
||||
///
|
||||
/// Protected constructor for derived 'Function' types to specify the actual input and output variables for the (primitive) Function instance.
|
||||
/// Protected constructors for derived user-defined 'Function' types to specify the actual input and output variables for the (primitive) Function instance.
|
||||
///
|
||||
CNTK_API Function(const std::vector<Variable>& inputs, Dictionary&& functionConfig, const std::wstring& name = L"", const std::wstring& uid = Internal::GenerateUid(L"UserDefinedFunction"));
|
||||
CNTK_API Function(const std::vector<Variable>& inputs, const Dictionary& functionConfig, const std::wstring& name = L"");
|
||||
CNTK_API Function(const std::vector<Variable>& inputs, const std::wstring& name = L"");
|
||||
|
||||
template <typename FunctionType>
|
||||
static void PreorderTraverseFunctions(const FunctionPtr& rootFunction, const FunctionType& functor, bool traverseInsideBlockFunction = false)
|
||||
|
@ -3420,14 +3424,11 @@ namespace CNTK
|
|||
// Disallow copy and move construction and assignment
|
||||
Function(const Function&) = delete; Function(Function&&) = delete; Function& operator=(const Function&) = delete; Function& operator=(Function&&) = delete;
|
||||
|
||||
public:
|
||||
CNTK_API Function(const std::vector<Variable>& inputs, const std::wstring& name = L"");
|
||||
|
||||
private:
|
||||
static UserFunctionFactoryPtr s_userFunctionFactory;
|
||||
|
||||
private:
|
||||
CNTK_API Function(const std::vector<Variable>& inputs, Dictionary&& functionConfig, const FunctionPtr& rootFunction, const std::wstring& name, const std::wstring& uid);
|
||||
Function(const std::vector<Variable>& inputs, const Dictionary& functionConfig, const FunctionPtr& rootFunction, const std::wstring& name, const std::wstring& uid);
|
||||
|
||||
std::vector<Variable> m_inputs;
|
||||
std::once_flag m_outputsInitFlag;
|
||||
|
@ -3437,6 +3438,22 @@ namespace CNTK
|
|||
std::wstring m_name;
|
||||
std::wstring m_uid;
|
||||
Dictionary m_attributes;
|
||||
|
||||
#ifdef SWIG
|
||||
public:
|
||||
void SetNative(bool native) { m_native = native; }
|
||||
#endif
|
||||
|
||||
private:
|
||||
bool IsNative() const { return m_native; }
|
||||
|
||||
bool m_native = true;
|
||||
|
||||
Dictionary SerializeNativeImpl() const;
|
||||
|
||||
static FunctionPtr DeserializeNativeImpl(const std::vector<Variable>& inputs, const std::wstring& name, const Dictionary& dict);
|
||||
|
||||
static const size_t s_serializationVersion = 1;
|
||||
};
|
||||
|
||||
///
|
||||
|
|
|
@ -376,8 +376,6 @@ namespace CNTK
|
|||
std::wstring m_fileName;
|
||||
};
|
||||
|
||||
|
||||
#ifdef SWIG
|
||||
// SWIG callback wrapper for the UDF deserialization.
|
||||
class UDFDeserializeCallbackWrapper
|
||||
{
|
||||
|
@ -385,8 +383,13 @@ namespace CNTK
|
|||
virtual FunctionPtr operator()(const std::vector<Variable>&, const std::wstring&, const Dictionary&) const = 0;
|
||||
virtual ~UDFDeserializeCallbackWrapper() = default;
|
||||
};
|
||||
#endif
|
||||
|
||||
typedef std::shared_ptr<UDFDeserializeCallbackWrapper> UDFDeserializeCallbackWrapperPtr;
|
||||
|
||||
CNTK_API void RegisterUDFDeserializeCallbackWrapper(UDFDeserializeCallbackWrapperPtr callbackPtr);
|
||||
|
||||
|
||||
CNTK_API bool IsNativeUserFunctionRegistered(const std::wstring& uniqueOpName);
|
||||
}
|
||||
|
||||
// Forward-declare test fixtures, so that they can be used as friends.
|
||||
|
|
|
@ -137,6 +137,7 @@
|
|||
<ClInclude Include="PrimitiveOpType.h" />
|
||||
<ClInclude Include="Serialization.h" />
|
||||
<ClInclude Include="tensorboard\TensorBoardUtils.h" />
|
||||
<ClInclude Include="UserDefinedFunction.h" />
|
||||
<ClInclude Include="UserFunctionFactory.h" />
|
||||
<ClInclude Include="Utils.h" />
|
||||
<ClInclude Include="stdafx.h" />
|
||||
|
@ -175,6 +176,7 @@
|
|||
<ClCompile Include="tensorboard\TensorBoardUtils.cpp" />
|
||||
<ClCompile Include="Trainer.cpp" />
|
||||
<ClCompile Include="TrainingSession.cpp" />
|
||||
<ClCompile Include="UserDefinedFunction.cpp" />
|
||||
<ClCompile Include="Utils.cpp" />
|
||||
<ClCompile Include="Value.cpp" />
|
||||
<ClCompile Include="Variable.cpp" />
|
||||
|
|
|
@ -36,6 +36,7 @@
|
|||
</ClCompile>
|
||||
<ClCompile Include="ProgressWriter.cpp" />
|
||||
<ClCompile Include="Evaluator.cpp" />
|
||||
<ClCompile Include="UserDefinedFunction.cpp" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="stdafx.h" />
|
||||
|
@ -67,6 +68,7 @@
|
|||
<ClInclude Include="BlockFunction.h" />
|
||||
<ClInclude Include="Variable.h" />
|
||||
<ClInclude Include="UserFunctionFactory.h" />
|
||||
<ClInclude Include="UserDefinedFunction.h" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<Filter Include="API">
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "BlockFunction.h"
|
||||
#include "SpecialPurposeNodes.h"
|
||||
#include "SequenceReshapeNodes.h"
|
||||
#include "UserDefinedFunction.h"
|
||||
|
||||
using namespace Microsoft::MSR::CNTK;
|
||||
|
||||
|
@ -58,7 +59,7 @@ namespace CNTK
|
|||
for (auto& function : m_allPrimitiveFunctions)
|
||||
{
|
||||
auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(function.get());
|
||||
if (!primitiveFunction->IsStateful())
|
||||
if (!primitiveFunction || !primitiveFunction->IsStateful())
|
||||
continue;
|
||||
|
||||
// TODO: same for BatchNorm
|
||||
|
@ -85,7 +86,7 @@ namespace CNTK
|
|||
for (auto& function : m_allPrimitiveFunctions)
|
||||
{
|
||||
auto primitiveFunction = dynamic_cast<const PrimitiveFunction*>(function.get());
|
||||
if (!primitiveFunction->IsStateful())
|
||||
if (!primitiveFunction || !primitiveFunction->IsStateful())
|
||||
continue;
|
||||
|
||||
// TODO: same for BatchNorm
|
||||
|
@ -224,7 +225,7 @@ namespace CNTK
|
|||
return composite;
|
||||
}
|
||||
|
||||
/*static*/ FunctionPtr CompositeFunction::Deserialize(const Dictionary& dict, const CNTK::DeviceDescriptor& device, const UDFDeserializeCallback& callback)
|
||||
/*static*/ FunctionPtr CompositeFunction::Deserialize(const Dictionary& dict, const CNTK::DeviceDescriptor& device)
|
||||
{
|
||||
static const vector<std::wstring> s_requiredDictionaryKeys = { inputsKey, functionsKey };
|
||||
|
||||
|
@ -254,7 +255,7 @@ namespace CNTK
|
|||
{
|
||||
auto functionDict = dictionaryValue.Value<Dictionary>();
|
||||
FunctionPtr root = UDFUtils::IsUDF(functionDict) ?
|
||||
UDFUtils::Deserialize(functionDict, uidToInputMap, device, callback) :
|
||||
UDFUtils::Deserialize(functionDict, uidToInputMap, device) :
|
||||
PrimitiveFunction::Deserialize(functionDict, uidToInputMap, allPrimitiveFunctions, allPlaceholderReplacements, device);
|
||||
allPrimitiveFunctions.insert(root);
|
||||
|
||||
|
@ -303,7 +304,7 @@ namespace CNTK
|
|||
for (auto& function : functions)
|
||||
{
|
||||
auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(function.get());
|
||||
if (!primitiveFunction->IsStateful())
|
||||
if (!primitiveFunction || !primitiveFunction->IsStateful())
|
||||
continue;
|
||||
|
||||
if (stateDictionary.Contains(primitiveFunction->Uid()))
|
||||
|
@ -340,7 +341,7 @@ namespace CNTK
|
|||
vector<wstring> uids;
|
||||
PreorderTraverseFunctions(function.RootFunction(), [&uids](const FunctionPtr& funcPtr) {
|
||||
auto primitiveFunction = dynamic_cast<const PrimitiveFunction*>(funcPtr.get());
|
||||
if (primitiveFunction->IsStateful())
|
||||
if (primitiveFunction && primitiveFunction->IsStateful())
|
||||
{
|
||||
uids.push_back(funcPtr->Uid());
|
||||
}
|
||||
|
@ -384,7 +385,7 @@ namespace CNTK
|
|||
for (const auto& function : m_allPrimitiveFunctions)
|
||||
{
|
||||
auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(function.get());
|
||||
if (!primitiveFunction->IsStateful())
|
||||
if (!primitiveFunction || !primitiveFunction->IsStateful())
|
||||
continue;
|
||||
|
||||
auto functionState = state[primitiveFunction->Uid()].Value<Dictionary>();
|
||||
|
|
|
@ -119,7 +119,7 @@ namespace CNTK
|
|||
const std::unordered_map<Variable, Variable>& allPlaceholderReplacements,
|
||||
const CNTK::DeviceDescriptor& device);
|
||||
|
||||
static FunctionPtr Deserialize(const Dictionary& dictionary, const CNTK::DeviceDescriptor& device, const UDFDeserializeCallback& callback);
|
||||
static FunctionPtr Deserialize(const Dictionary& dictionary, const CNTK::DeviceDescriptor& device);
|
||||
|
||||
virtual const std::wstring& OpName() const override
|
||||
{
|
||||
|
|
|
@ -27,6 +27,30 @@ namespace CNTK
|
|||
return AsComposite(s_userFunctionFactory->CreateInstance(opName, operands, functionConfig, userFunctionInstanceName), userFunctionInstanceName);
|
||||
}
|
||||
|
||||
bool Internal::IsNativeUserFunctionRegistered(const std::wstring& uniqueOpName)
|
||||
{
|
||||
return Function::s_userFunctionFactory->IsRegistered(uniqueOpName);
|
||||
}
|
||||
|
||||
static std::unordered_map<std::wstring, UDFDeserializeCallbackPtr> udfCallbackMap;
|
||||
static std::mutex udfCallbackMapMutex;
|
||||
|
||||
/*static*/ void Function::RegisterUDFDeserializeCallback(const std::wstring& uniqueOpName, const UDFDeserializeCallback& deserializer)
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(udfCallbackMapMutex);
|
||||
auto result = udfCallbackMap.insert({ uniqueOpName, make_shared<UDFDeserializeCallback>(deserializer) });
|
||||
if (!result.second)
|
||||
InvalidArgument("A callback for the UserFunction with op name '%S' has already been registered.", uniqueOpName.c_str());
|
||||
}
|
||||
|
||||
/*static*/ UDFDeserializeCallbackPtr Function::GetUDFDeserializeCallback(const std::wstring& uniqueOpName)
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(udfCallbackMapMutex);
|
||||
if (udfCallbackMap.find(uniqueOpName) == udfCallbackMap.end())
|
||||
return nullptr;
|
||||
return udfCallbackMap.at(uniqueOpName);
|
||||
}
|
||||
|
||||
std::vector<Variable>& Function::InitOutputs()
|
||||
{
|
||||
std::call_once(m_outputsInitFlag, [this]() {
|
||||
|
@ -85,10 +109,6 @@ namespace CNTK
|
|||
return std::shared_ptr<std::vector<Variable>>(new std::vector<Variable>(std::move(inputs)), [](std::vector<Variable>* ptr) { delete ptr; });
|
||||
}
|
||||
|
||||
Function::Function(const std::vector<Variable>& inputs, Dictionary&& functionConfig, const std::wstring& name, const std::wstring& uid)
|
||||
: Function(inputs, std::move(functionConfig), nullptr, name, uid)
|
||||
{}
|
||||
|
||||
std::shared_ptr<std::vector<Variable>> Function::OutputsImpl() const
|
||||
{
|
||||
std::vector<Variable> outputs;
|
||||
|
@ -99,11 +119,7 @@ namespace CNTK
|
|||
return std::shared_ptr<std::vector<Variable>>(new std::vector<Variable>(std::move(outputs)), [](std::vector<Variable>* ptr) { delete ptr; });
|
||||
}
|
||||
|
||||
Function::Function(const std::vector<Variable>& inputs, const std::wstring& name)
|
||||
: Function(inputs, Dictionary(), name)
|
||||
{}
|
||||
|
||||
Function::Function(const std::vector<Variable>& inputs, Dictionary&& functionConfig, const FunctionPtr& rootFunction, const std::wstring& name, const std::wstring& uid)
|
||||
Function::Function(const std::vector<Variable>& inputs, const Dictionary& functionConfig, const FunctionPtr& rootFunction, const std::wstring& name, const std::wstring& uid)
|
||||
: m_rootFunction(rootFunction), m_name(name), m_uid(uid), m_attributes(std::move(functionConfig))
|
||||
{
|
||||
for (auto inputVar : inputs)
|
||||
|
@ -438,14 +454,14 @@ namespace CNTK
|
|||
stream->flush();
|
||||
}
|
||||
|
||||
/*static*/ FunctionPtr Function::Load(const std::wstring& filepath, const DeviceDescriptor& computeDevice, const UDFDeserializeCallback& callback)
|
||||
/*static*/ FunctionPtr Function::Load(const std::wstring& filepath, const DeviceDescriptor& computeDevice)
|
||||
{
|
||||
auto stream = GetFstream(filepath, true);
|
||||
if (!Internal::IsLegacyModel(*stream))
|
||||
{
|
||||
Dictionary model;
|
||||
*stream >> model;
|
||||
return Function::Deserialize(model, computeDevice, callback);
|
||||
return Function::Deserialize(model, computeDevice);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
@ -453,7 +469,7 @@ namespace CNTK
|
|||
}
|
||||
}
|
||||
|
||||
/*static*/ FunctionPtr Function::Load(const char *buffer, size_t length, const DeviceDescriptor& computeDevice, const UDFDeserializeCallback& callback)
|
||||
/*static*/ FunctionPtr Function::Load(const char *buffer, size_t length, const DeviceDescriptor& computeDevice)
|
||||
{
|
||||
if ((buffer == nullptr) || (length <= 0))
|
||||
InvalidArgument("The model buffer should not be null and its length should be greater than 0");
|
||||
|
@ -474,15 +490,15 @@ namespace CNTK
|
|||
modelStreamBuffer buf(buffer, length);
|
||||
std::istream modelStream(&buf);
|
||||
|
||||
return Load(modelStream, computeDevice, callback);
|
||||
return Load(modelStream, computeDevice);
|
||||
}
|
||||
}
|
||||
|
||||
/*static*/ FunctionPtr Function::Load(std::istream& inputStream, const DeviceDescriptor& computeDevice, const UDFDeserializeCallback& callback)
|
||||
/*static*/ FunctionPtr Function::Load(std::istream& inputStream, const DeviceDescriptor& computeDevice)
|
||||
{
|
||||
Dictionary model;
|
||||
inputStream >> model;
|
||||
return Function::Deserialize(model, computeDevice, callback);
|
||||
return Function::Deserialize(model, computeDevice);
|
||||
}
|
||||
|
||||
void Function::Restore(const std::wstring& filepath)
|
||||
|
@ -904,9 +920,9 @@ namespace CNTK
|
|||
compositeFunction->CopyState(*restoredCompositeFunction);
|
||||
}
|
||||
|
||||
/*static*/ FunctionPtr Function::Deserialize(const Dictionary& modelDictionary, const CNTK::DeviceDescriptor& device, const UDFDeserializeCallback& callback)
|
||||
/*static*/ FunctionPtr Function::Deserialize(const Dictionary& modelDictionary, const CNTK::DeviceDescriptor& device)
|
||||
{
|
||||
return CompositeFunction::Deserialize(modelDictionary, device, callback);
|
||||
return CompositeFunction::Deserialize(modelDictionary, device);
|
||||
}
|
||||
|
||||
void Function::PrintGraph() const
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include "ConvolveGeometry.h"
|
||||
#include "ConvolutionalNodes.h"
|
||||
#include "Variable.h"
|
||||
#include "UserFunctionFactory.h"
|
||||
|
||||
using namespace Microsoft::MSR::CNTK;
|
||||
|
||||
|
@ -963,7 +964,7 @@ namespace CNTK
|
|||
}
|
||||
}
|
||||
|
||||
static vector<DictionaryValue> GetInputUids(const Function& f)
|
||||
vector<DictionaryValue> GetInputUids(const Function& f)
|
||||
{
|
||||
auto inputs = f.Inputs();
|
||||
vector<DictionaryValue> inputUids;
|
||||
|
@ -975,7 +976,7 @@ namespace CNTK
|
|||
return inputUids;
|
||||
}
|
||||
|
||||
static Dictionary SerializeCommonAttributes(const Function& f, size_t version, const wstring& functionType)
|
||||
Dictionary SerializeCommonFunctionAttributes(const Function& f, size_t version, const wstring& functionType)
|
||||
{
|
||||
Dictionary dict;
|
||||
dict[versionKey] = version;
|
||||
|
@ -991,7 +992,7 @@ namespace CNTK
|
|||
|
||||
/*virtual*/ Dictionary PrimitiveFunction::Serialize() const
|
||||
{
|
||||
Dictionary dict = SerializeCommonAttributes(*this, CurrentVersion(), s_primitiveFunctionTypeValue);
|
||||
Dictionary dict = SerializeCommonFunctionAttributes(*this, CurrentVersion(), s_primitiveFunctionTypeValue);
|
||||
dict[opKey] = static_cast<size_t>(m_op);
|
||||
dict[attributesKey] = Attributes();
|
||||
|
||||
|
@ -1018,7 +1019,7 @@ namespace CNTK
|
|||
return dict;
|
||||
}
|
||||
|
||||
static std::vector<Variable> GetInputVariables(const Dictionary& dict, const unordered_map<wstring, Variable>& uidToVariableMap, size_t currentSerializationVersion)
|
||||
std::vector<Variable> GetInputVariables(const Dictionary& dict, const std::unordered_map<std::wstring, Variable>& uidToVariableMap, size_t currentSerializationVersion)
|
||||
{
|
||||
const auto& inputUids = dict[inputsKey].Value<vector<DictionaryValue>>();
|
||||
|
||||
|
@ -1246,54 +1247,4 @@ namespace CNTK
|
|||
|
||||
return dummyOutputVariable.Shape();
|
||||
}
|
||||
|
||||
static const std::wstring s_userDefinedFunctionTypeValue = L"UserDefinedFunction";
|
||||
|
||||
/*static*/ bool UDFUtils::IsUDF(const FunctionPtr& f)
|
||||
{
|
||||
return (dynamic_cast<const PrimitiveFunction*>(f.get()) == nullptr);
|
||||
}
|
||||
|
||||
/*static*/ bool UDFUtils::IsUDF(const Dictionary& dict)
|
||||
{
|
||||
return (dict.Contains(typeKey) && dict[typeKey].Value<std::wstring>() == s_userDefinedFunctionTypeValue);
|
||||
}
|
||||
|
||||
/*static*/ Dictionary UDFUtils::Serialize(const FunctionPtr& udf)
|
||||
{
|
||||
Dictionary dict = SerializeCommonAttributes(*udf, s_serializationVersion, s_userDefinedFunctionTypeValue);
|
||||
dict[userDefinedStateKey] = udf->Serialize();
|
||||
return dict;
|
||||
}
|
||||
|
||||
/*static*/ FunctionPtr UDFUtils::Deserialize(const Dictionary& dict,
|
||||
const unordered_map<std::wstring, Variable>& uidToVariableMap,
|
||||
const DeviceDescriptor& device,
|
||||
const UDFDeserializeCallback& callback)
|
||||
{
|
||||
static const vector<std::wstring> s_requiredDictionaryKeys = { typeKey, uidKey, inputsKey, userDefinedStateKey };
|
||||
ValidateDictionary<PrimitiveFunction>(dict, s_requiredDictionaryKeys, s_userDefinedFunctionTypeValue, s_serializationVersion);
|
||||
|
||||
const auto& uid = dict[uidKey].Value<std::wstring>();
|
||||
std::wstring name = L"";
|
||||
if (dict.Contains(nameKey))
|
||||
name = dict[nameKey].Value<std::wstring>();
|
||||
|
||||
auto inputs = GetInputVariables(dict, uidToVariableMap, s_serializationVersion);
|
||||
|
||||
auto state = dict[userDefinedStateKey].Value<Dictionary>();
|
||||
|
||||
auto udf = callback(inputs, name, state);
|
||||
|
||||
if (udf == nullptr)
|
||||
{
|
||||
RuntimeError("Unable to reconstruct a user-defined function. Please make sure to specify a valid UDF deserializer.");
|
||||
}
|
||||
|
||||
// Restore the original uid, which other functions in the graph depend on
|
||||
// (their inputs refer to the uids of this UDF outputs, which are generated base on the uid of this UDF).
|
||||
udf->m_uid = uid;
|
||||
|
||||
return udf;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -263,7 +263,7 @@ namespace CNTK
|
|||
|
||||
protected:
|
||||
PrimitiveFunction(PrimitiveOpType op, const std::vector<Variable>& inputs, Dictionary&& functionConfig, const std::wstring& functionName, const std::wstring& uid)
|
||||
: Function(inputs, std::move(functionConfig), functionName, uid), m_op(op)
|
||||
: Function(inputs, std::move(functionConfig), nullptr, functionName, uid), m_op(op)
|
||||
{}
|
||||
|
||||
public:
|
||||
|
@ -759,21 +759,7 @@ namespace CNTK
|
|||
static const size_t s_serializationVersion = 13;
|
||||
};
|
||||
|
||||
class UDFUtils
|
||||
{
|
||||
public:
|
||||
|
||||
static bool IsUDF(const FunctionPtr& f);
|
||||
|
||||
static bool IsUDF(const Dictionary& dict);
|
||||
|
||||
static Dictionary Serialize(const FunctionPtr& dictionary);
|
||||
|
||||
static FunctionPtr Deserialize(const Dictionary& dictionary,
|
||||
const std::unordered_map<std::wstring, Variable>& uidToVariableMap,
|
||||
const CNTK::DeviceDescriptor& device,
|
||||
const UDFDeserializeCallback& callback);
|
||||
|
||||
static const size_t s_serializationVersion = 0;
|
||||
};
|
||||
std::vector<DictionaryValue> GetInputUids(const Function& f);
|
||||
Dictionary SerializeCommonFunctionAttributes(const Function& f, size_t version, const std::wstring& functionType);
|
||||
std::vector<Variable> GetInputVariables(const Dictionary& dict, const std::unordered_map<std::wstring, Variable>& uidToVariableMap, size_t currentSerializationVersion);
|
||||
}
|
||||
|
|
|
@ -45,6 +45,9 @@ namespace CNTK
|
|||
const std::wstring internalWorkerStateKey = L"internal_worker_state";
|
||||
const std::wstring externalWorkerStateKey = L"external_worker_state";
|
||||
const std::wstring userDefinedStateKey = L"user_defined_state";
|
||||
const std::wstring udfModuleNameKey = L"module";
|
||||
const std::wstring udfFactoryMethodNameKey = L"deserialize_method";
|
||||
const std::wstring nativeUDFKey = L"native";
|
||||
|
||||
template <typename T>
|
||||
inline std::string GetVersionsString(size_t currentVersion, size_t dictVersion)
|
||||
|
|
|
@ -0,0 +1,178 @@
|
|||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
|
||||
//
|
||||
|
||||
#include "UserDefinedFunction.h"
|
||||
#include "UserFunctionFactory.h"
|
||||
#include "Serialization.h"
|
||||
#include "PrimitiveFunction.h"
|
||||
#include "CompositeFunction.h"
|
||||
|
||||
|
||||
namespace CNTK
|
||||
{
|
||||
|
||||
static Internal::UDFDeserializeCallbackWrapperPtr s_SWIGCallbackWrapper;
|
||||
|
||||
void Internal::RegisterUDFDeserializeCallbackWrapper(UDFDeserializeCallbackWrapperPtr callbackPtr)
|
||||
{
|
||||
s_SWIGCallbackWrapper = callbackPtr;
|
||||
}
|
||||
|
||||
static const std::wstring s_nativeUDFTypeValue = L"NativeUserDefinedFunction" ;
|
||||
|
||||
static std::unordered_map<std::wstring, std::pair<std::wstring, std::wstring>> s_deserializedUDFsRegistry;
|
||||
|
||||
Function::Function(const std::vector<Variable>& inputs, const std::wstring& name)
|
||||
: Function(inputs, {}, name)
|
||||
{}
|
||||
|
||||
Function::Function(const std::vector<Variable>& inputs, const Dictionary& functionConfig, const std::wstring& name)
|
||||
: Function(inputs, functionConfig, nullptr, name, Internal::GenerateUid(L"UserDefinedFunction"))
|
||||
{}
|
||||
|
||||
/*static*/ FunctionPtr Function::DeserializeNativeImpl(const std::vector<Variable>& inputs, const std::wstring& name, const Dictionary& dict)
|
||||
{
|
||||
static const vector<std::wstring> s_requiredDictionaryKeys = { userDefinedStateKey, udfModuleNameKey, udfFactoryMethodNameKey, opKey };
|
||||
ValidateDictionary<PrimitiveFunction>(dict, s_requiredDictionaryKeys, s_nativeUDFTypeValue, s_serializationVersion);
|
||||
|
||||
auto state = dict[userDefinedStateKey].Value<Dictionary>();
|
||||
auto opName = dict[opKey].Value<wstring>();
|
||||
auto moduleName = dict[udfModuleNameKey].Value<wstring>();
|
||||
auto methodName = dict[udfFactoryMethodNameKey].Value<wstring>();
|
||||
|
||||
FunctionPtr udf = nullptr;
|
||||
|
||||
auto callback = Function::GetUDFDeserializeCallback(opName);
|
||||
if (callback != nullptr)
|
||||
{
|
||||
udf = callback->operator()(inputs, name, state);
|
||||
}
|
||||
else
|
||||
{
|
||||
Microsoft::MSR::CNTK::Plugin plugin;
|
||||
auto factoryMethod = (UserFunctionFactoryMethodType)(plugin.Load(moduleName, msra::strfun::utf8(methodName), /*isCNTKPlugin =*/ false));
|
||||
udf = FunctionPtr(factoryMethod(inputs.data(), inputs.size(), &state, name.c_str()));
|
||||
}
|
||||
|
||||
if (udf == nullptr)
|
||||
{
|
||||
RuntimeError("Unable to reconstruct the native UserFunction with op name '%S'", opName.c_str());
|
||||
}
|
||||
|
||||
s_deserializedUDFsRegistry[opName] = { moduleName, methodName };
|
||||
|
||||
return udf;
|
||||
}
|
||||
|
||||
/*virtual*/ std::wstring Function::ModuleName() const
|
||||
{
|
||||
auto it = s_deserializedUDFsRegistry.find(OpName());
|
||||
if (it != s_deserializedUDFsRegistry.end())
|
||||
{
|
||||
auto moduleAndMethodPair = it->second;
|
||||
return moduleAndMethodPair.first;
|
||||
}
|
||||
|
||||
// this op name was never registered in the s_deserializedUDFsRegistry (which only happens during the deserialization),
|
||||
// then use user factory as a fallback (this udf must have been registed, so that an instance could be created).
|
||||
return s_userFunctionFactory->GetModuleName(OpName());
|
||||
}
|
||||
|
||||
/*virtual*/ std::wstring Function::DeserializeMethodName() const
|
||||
{
|
||||
auto it = s_deserializedUDFsRegistry.find(OpName());
|
||||
if (it != s_deserializedUDFsRegistry.end())
|
||||
{
|
||||
auto moduleAndMethodPair = it->second;
|
||||
return moduleAndMethodPair.second;
|
||||
}
|
||||
|
||||
// this op name was never registered in the s_deserializedUDFsRegistry (which only happens during the deserialization),
|
||||
// then use user factory as a fallback (this udf must have been registed, so that an instance could be created).
|
||||
return Function::s_userFunctionFactory->GetFactoryMethodName(OpName());
|
||||
}
|
||||
|
||||
Dictionary Function::SerializeNativeImpl() const
|
||||
{
|
||||
Dictionary dict;
|
||||
dict[userDefinedStateKey] = Serialize();
|
||||
dict[udfModuleNameKey] = ModuleName();
|
||||
dict[udfFactoryMethodNameKey] = DeserializeMethodName();
|
||||
dict[opKey] = OpName();
|
||||
dict[versionKey] = s_serializationVersion;
|
||||
dict[typeKey] = s_nativeUDFTypeValue;
|
||||
return dict;
|
||||
}
|
||||
|
||||
static const std::wstring s_userDefinedFunctionTypeValue = L"UserDefinedFunction";
|
||||
|
||||
/*static*/ bool UDFUtils::IsUDF(const FunctionPtr& f)
|
||||
{
|
||||
return (dynamic_cast<const PrimitiveFunction*>(f.get()) == nullptr) &&
|
||||
(dynamic_cast<const CompositeFunction*>(f.get()) == nullptr);
|
||||
}
|
||||
|
||||
/*static*/ bool UDFUtils::IsUDF(const Dictionary& dict)
|
||||
{
|
||||
return (dict.Contains(typeKey) && dict[typeKey].Value<std::wstring>() == s_userDefinedFunctionTypeValue);
|
||||
}
|
||||
|
||||
/*static*/ bool UDFUtils::IsNativeUDF(const Dictionary& dict)
|
||||
{
|
||||
assert(IsUDF(dict));
|
||||
return (dict.Contains(nativeUDFKey) && dict[nativeUDFKey].Value<bool>() == true);
|
||||
}
|
||||
|
||||
/*static*/ Dictionary UDFUtils::Serialize(const FunctionPtr& f)
|
||||
{
|
||||
Dictionary dict = SerializeCommonFunctionAttributes(*f, s_serializationVersion, s_userDefinedFunctionTypeValue);
|
||||
bool native = f->IsNative();
|
||||
dict[nativeUDFKey] = native;
|
||||
dict[userDefinedStateKey] = (native) ? f->SerializeNativeImpl() : f->Serialize();
|
||||
return dict;
|
||||
}
|
||||
|
||||
/*static*/ FunctionPtr UDFUtils::Deserialize(const Dictionary& dict,
|
||||
const unordered_map<std::wstring, Variable>& uidToVariableMap,
|
||||
const DeviceDescriptor& device)
|
||||
{
|
||||
static const vector<std::wstring> s_requiredDictionaryKeys = { typeKey, uidKey, inputsKey, userDefinedStateKey };
|
||||
ValidateDictionary<PrimitiveFunction>(dict, s_requiredDictionaryKeys, s_userDefinedFunctionTypeValue, s_serializationVersion);
|
||||
|
||||
const auto& uid = dict[uidKey].Value<std::wstring>();
|
||||
std::wstring name = L"";
|
||||
if (dict.Contains(nameKey))
|
||||
name = dict[nameKey].Value<std::wstring>();
|
||||
|
||||
auto inputs = GetInputVariables(dict, uidToVariableMap, s_serializationVersion);
|
||||
|
||||
auto state = dict[userDefinedStateKey].Value<Dictionary>();
|
||||
|
||||
FunctionPtr udf;
|
||||
|
||||
if (IsNativeUDF(dict))
|
||||
{
|
||||
udf = Function::DeserializeNativeImpl(inputs, name, state);
|
||||
}
|
||||
else if (s_SWIGCallbackWrapper != nullptr)
|
||||
{
|
||||
// If we're being called from SWIG, the actual deserializer should be registered by
|
||||
// the target language CNTK implementation (i.e., cnkt_py for Python)
|
||||
udf = s_SWIGCallbackWrapper->operator()(inputs, name, state);
|
||||
}
|
||||
|
||||
if (udf == nullptr)
|
||||
{
|
||||
RuntimeError("Unable to reconstruct a user-defined function (name = %S, uid = %S). "
|
||||
"Please make sure to specify a valid UDF deserializer.", name.c_str(), uid.c_str());
|
||||
}
|
||||
|
||||
// Restore the original uid, which other functions in the graph depend on
|
||||
// (their inputs refer to the uids of this UDF outputs, which are generated base on the uid of this UDF).
|
||||
udf->m_uid = uid;
|
||||
|
||||
return udf;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,32 @@
|
|||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "stdafx.h"
|
||||
#include "CNTKLibrary.h"
|
||||
|
||||
namespace CNTK
|
||||
{
|
||||
class UDFUtils
|
||||
{
|
||||
public:
|
||||
|
||||
static bool IsUDF(const FunctionPtr& f);
|
||||
|
||||
static bool IsUDF(const Dictionary& dict);
|
||||
|
||||
static Dictionary Serialize(const FunctionPtr& f);
|
||||
|
||||
static FunctionPtr Deserialize(const Dictionary& dictionary,
|
||||
const std::unordered_map<std::wstring, Variable>& uidToVariableMap,
|
||||
const CNTK::DeviceDescriptor& device);
|
||||
|
||||
static const size_t s_serializationVersion = 0;
|
||||
|
||||
private:
|
||||
static bool IsNativeUDF(const Dictionary& dict);
|
||||
};
|
||||
}
|
|
@ -11,30 +11,60 @@
|
|||
|
||||
namespace CNTK
|
||||
{
|
||||
typedef Function* (*UserFunctionFactoryMethodType)(const Variable* operands, size_t numOperands, const Dictionary* attributes, const wchar_t* name);
|
||||
|
||||
class UserFunctionFactory : public std::enable_shared_from_this<UserFunctionFactory>
|
||||
{
|
||||
typedef Function* (*UserFunctionFactoryMethodType)(const Variable* operands, size_t numOperands, const Dictionary* attributes, const wchar_t* name);
|
||||
|
||||
public:
|
||||
|
||||
bool IsRegistered(const std::wstring& uniqueOpName)
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(m_mutex);
|
||||
return IsRegisteredImpl(uniqueOpName);
|
||||
}
|
||||
|
||||
void Register(const std::wstring& uniqueOpName, const std::wstring& moduleName, const std::wstring& factoryMethodName)
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(m_mutex);
|
||||
if (m_userFunctionFactoryMethodMap.find(uniqueOpName) != m_userFunctionFactoryMethodMap.end())
|
||||
if (IsRegisteredImpl(uniqueOpName))
|
||||
InvalidArgument("UserFunction with op name '%S' is already registered. All UserFunction op names must be unique.", uniqueOpName.c_str());
|
||||
|
||||
m_userFunctionFactoryMethodMap[uniqueOpName] = std::make_shared<UserFunctionFactoryMethodInfo>(moduleName, factoryMethodName);
|
||||
}
|
||||
|
||||
std::wstring GetModuleName(const std::wstring& uniqueOpName)
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(m_mutex);
|
||||
if (!IsRegisteredImpl(uniqueOpName))
|
||||
InvalidArgument("UserFunction with op name '%S' has not been registered.", uniqueOpName.c_str());
|
||||
|
||||
return m_userFunctionFactoryMethodMap.at(uniqueOpName)->m_moduleName;
|
||||
}
|
||||
|
||||
std::wstring GetFactoryMethodName(const std::wstring& uniqueOpName)
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(m_mutex);
|
||||
if (!IsRegisteredImpl(uniqueOpName))
|
||||
InvalidArgument("UserFunction with op name '%S' has not been registered.", uniqueOpName.c_str());
|
||||
|
||||
return m_userFunctionFactoryMethodMap.at(uniqueOpName)->m_factoryMethodName;
|
||||
}
|
||||
|
||||
FunctionPtr CreateInstance(const std::wstring& opName, const std::vector<Variable>& inputs, const Dictionary& functionConfig, const std::wstring& userFunctionInstanceName)
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(m_mutex);
|
||||
if (m_userFunctionFactoryMethodMap.find(opName) == m_userFunctionFactoryMethodMap.end())
|
||||
if (!IsRegisteredImpl(opName))
|
||||
InvalidArgument("UserFunction with op name '%S' has not been registered.", opName.c_str());
|
||||
|
||||
return std::shared_ptr<Function>(m_userFunctionFactoryMethodMap.at(opName)->m_factoryMethod(inputs.data(), inputs.size(), &functionConfig, userFunctionInstanceName.c_str()));
|
||||
}
|
||||
|
||||
private:
|
||||
bool IsRegisteredImpl(const std::wstring& uniqueOpName)
|
||||
{
|
||||
return (m_userFunctionFactoryMethodMap.find(uniqueOpName) != m_userFunctionFactoryMethodMap.end());
|
||||
}
|
||||
|
||||
struct UserFunctionFactoryMethodInfo : public std::enable_shared_from_this<UserFunctionFactoryMethodInfo>
|
||||
{
|
||||
UserFunctionFactoryMethodInfo(const std::wstring& moduleName, const std::wstring& factoryMethodName)
|
||||
|
@ -53,4 +83,4 @@ namespace CNTK
|
|||
std::mutex m_mutex;
|
||||
std::unordered_map<std::wstring, UserFunctionFactoryMethodInfoPtr> m_userFunctionFactoryMethodMap;
|
||||
};
|
||||
}
|
||||
}
|
|
@ -1927,25 +1927,34 @@ SWIG_STD_VECTOR_ENHANCED(CNTK::DeviceDescriptor)
|
|||
|
||||
|
||||
|
||||
%ignore CNTK::Function::Load(const std::wstring& filepath, const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice(), const Internal::UDFDeserializerPtr& deserializer);
|
||||
%ignore CNTK::Function::Load(const char* buffer, size_t length, const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice(), const Internal::UDFDeserializerPtr& deserializer = nullptr);
|
||||
%ignore CNTK::Function::Load(const std::wstring& filepath, const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());
|
||||
%ignore CNTK::Function::Load(const char* buffer, size_t length, const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());
|
||||
// Ignore exposing istream to C# for now. Todo: find a good solution to map C# System.IO.Stream to std::istream.
|
||||
%ignore CNTK::Function::Load(std::istream& inputStream, const DeviceDescriptor& computeDevice= DeviceDescriptor::UseDefaultDevice(), const Internal::UDFDeserializerPtr& deserializer = nullptr);
|
||||
%ignore CNTK::Function::Load(std::istream& inputStream, const DeviceDescriptor& computeDevice= DeviceDescriptor::UseDefaultDevice());
|
||||
|
||||
%extend CNTK::Function {
|
||||
static FunctionPtr Load(const std::wstring& filepath,
|
||||
const CNTK::DeviceDescriptor& computeDevice = CNTK::DeviceDescriptor::UseDefaultDevice())
|
||||
{
|
||||
return CNTK::Function::Load(filepath, computeDevice, nullptr);
|
||||
return CNTK::Function::Load(filepath, computeDevice);
|
||||
}
|
||||
|
||||
static FunctionPtr Load(const char* modelBuffer, size_t length,
|
||||
const CNTK::DeviceDescriptor& computeDevice = CNTK::DeviceDescriptor::UseDefaultDevice())
|
||||
{
|
||||
return CNTK::Function::Load(modelBuffer, length, computeDevice, nullptr);
|
||||
return CNTK::Function::Load(modelBuffer, length, computeDevice);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
%ignore CNTK::Function::RegisterUDFDeserializeCallback;
|
||||
%ignore CNTK::Function::GetUDFDeserializeCallback;
|
||||
%ignore_class CNTK::Internal::UDFDeserializeCallbackWrapper;
|
||||
%ignore_function CNTK::Internal::RegisterUDFDeserializeCallbackWrapper;
|
||||
%ignore_function CNTK::Internal::IsNativeUserFunctionRegistered;
|
||||
|
||||
|
||||
|
||||
%include "CNTKLibraryInternals.h"
|
||||
%include "CNTKLibrary.h"
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
%rename(_add_progress_writers) CNTK::Internal::AddProgressWriters;
|
||||
%rename(_backward) CNTK::Function::Backward;
|
||||
%rename(_infer_outputs) CNTK::Function::InferOutputs;
|
||||
%rename(_serialize) CNTK::Function::Serialize;
|
||||
%rename(_serialize_impl) CNTK::Function::Serialize;
|
||||
%rename(_deserialize) CNTK::Function::Deserialize;
|
||||
%rename(_update) CNTK::Learner::Update;
|
||||
%rename(sgd_learner) CNTK::SGDLearner;
|
||||
|
@ -39,9 +39,9 @@
|
|||
%rename(ctf_deserializer) CNTK::CTFDeserializer;
|
||||
%rename(htk_feature_deserializer) CNTK::HTKFeatureDeserializer;
|
||||
%rename(htk_mlf_deserializer) CNTK::HTKMLFDeserializer;
|
||||
%rename(_infer_outputs) CNTK::Function::InferOutputs;
|
||||
%rename(_stream_infos) CNTK::SwigMinibatchSource::StreamInfos(PyObject*);
|
||||
%rename(_next_minibatch) CNTK::SwigMinibatchSource::_GetNextMinibatch;
|
||||
%rename(_register_udf_deserialize_callback) CNTK::Internal::RegisterUDFDeserializeCallbackWrapper;
|
||||
|
||||
%rename(_none) CNTK::DictionaryValue::Type::None;
|
||||
|
||||
|
@ -174,7 +174,8 @@
|
|||
%ignore CNTK::Internal::TensorBoardFileWriter::TensorBoardFileWriter(const std::wstring& dir, const ::Microsoft::MSR::CNTK::ComputationNetworkPtr& modelToVisualize = nullptr);
|
||||
%ignore CNTK::Internal::Convolution;
|
||||
|
||||
%ignore CNTK::Function::Function(const std::vector<Variable>& inputs, Dictionary&& functionConfig, const std::wstring& name = L"", const std::wstring& uid = Internal::GenerateUid(L"UserDefinedFunction"));
|
||||
%ignore CNTK::Function::RegisterUDFDeserializeCallback;
|
||||
%ignore CNTK::Function::GetUDFDeserializeCallback;
|
||||
|
||||
%{
|
||||
#define SWIG_FILE_WITH_INIT
|
||||
|
@ -617,17 +618,32 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
// Callback support
|
||||
|
||||
%feature("director") CNTK::Function;
|
||||
%feature("director") CNTK::Internal::UDFDeserializeCallbackWrapper;
|
||||
%feature("nodirector") CNTK::Function::OnPlaceholdersReplaced;
|
||||
%feature("nodirector") CNTK::Function::OpName;
|
||||
// Callback support
|
||||
%feature("director") CNTK::Internal::UDFDeserializeCallbackWrapper;
|
||||
|
||||
%typemap(directorout) std::shared_ptr<CNTK::Function> (void * swig_argp, int swig_res = 0) {
|
||||
if ($input == Py_None) {
|
||||
$result = $ltype();
|
||||
} else {
|
||||
swig_res = SWIG_ConvertPtr($input, &swig_argp, $descriptor(std::shared_ptr<CNTK::Function> *), %convertptr_flags);
|
||||
if (!SWIG_IsOK(swig_res)) {
|
||||
%dirout_fail(swig_res,"$type");
|
||||
}
|
||||
$result = *(%reinterpret_cast(swig_argp, $<ype));
|
||||
}
|
||||
}
|
||||
|
||||
// Since there're three overloads of Function::Load, both "rename" and "compactdefaultargs" are needed
|
||||
// to make pybuffer_binary work correctly with default arguments.
|
||||
%rename(load_from_buffer) CNTK::Function::Load(const char*, size_t, const DeviceDescriptor&);
|
||||
%feature("compactdefaultargs") CNTK::Function::Load(const char*, size_t, const DeviceDescriptor&);
|
||||
|
||||
// This overload is not used in python at the moment.
|
||||
%ignore CNTK::Function::Load(const std::wstring&, const DeviceDescriptor&, const UDFDeserializeCallback&);
|
||||
%ignore CNTK::Function::Load(const char*, size_t, const DeviceDescriptor&, const UDFDeserializeCallback&);
|
||||
%ignore CNTK::Function::Load(std::istream&, const DeviceDescriptor&, const UDFDeserializeCallback&);
|
||||
%rename(load_from_buffer) CNTK::Function::Load(const char*, size_t, const CNTK::DeviceDescriptor&, const CNTK::Internal::UDFDeserializeCallbackWrapper&);
|
||||
%ignore CNTK::Function::Load(std::istream&, const DeviceDescriptor&);
|
||||
|
||||
%feature("director") CNTK::Learner;
|
||||
%feature("nodirector") CNTK::Learner::Parameters;
|
||||
|
@ -812,7 +828,7 @@ public:
|
|||
}
|
||||
%enddef
|
||||
|
||||
// Implementing typemapping for UDFDeserializeCallback (it has a dictionary
|
||||
// Implementing typemapping for UDFDeserializeCallbackWrapper (it has a dictionary
|
||||
// as one of its input parameters), which needs to be implemented in Python.
|
||||
|
||||
%typemap(directorin, fragment="DictionaryValueToPy") const CNTK::Dictionary&
|
||||
|
@ -1403,6 +1419,7 @@ std::unordered_map<CNTK::StreamInformation, std::pair<CNTK::NDArrayViewPtr, CNTK
|
|||
%shared_ptr(CNTK::DistributedLearner)
|
||||
%shared_ptr(CNTK::Internal::TensorBoardFileWriter)
|
||||
%shared_ptr(CNTK::ProgressWriter)
|
||||
%shared_ptr(CNTK::Internal::UDFDeserializeCallbackWrapper)
|
||||
|
||||
%include "CNTKLibraryInternals.h"
|
||||
%include "CNTKLibrary.h"
|
||||
|
@ -1784,6 +1801,7 @@ namespace CNTK
|
|||
%template(random_uniform_float) CNTK::NDArrayView::RandomUniform<float>;
|
||||
%template(random_uniform_double) CNTK::NDArrayView::RandomUniform<double>;
|
||||
%template(DictionaryValueFromDict) CNTK::DictionaryValue::DictionaryValue<CNTK::Dictionary>;
|
||||
%template(DictionaryValueFromNDArrayView) CNTK::DictionaryValue::DictionaryValue<CNTK::NDArrayView>;
|
||||
|
||||
%template(training_parameter_per_sample_schedule) CNTK::TrainingParameterPerUnitSchedule<double, CNTK::TrainingParameterSchedule<double>::UnitType::Sample>;
|
||||
%template(training_parameter_per_minibatch_schedule) CNTK::TrainingParameterPerUnitSchedule<double, CNTK::TrainingParameterSchedule<double>::UnitType::Minibatch>;
|
||||
|
|
|
@ -23,31 +23,38 @@ def _value_as_sequence_or_array(val, var):
|
|||
map_if_possible(val)
|
||||
return val.asarray()
|
||||
|
||||
_serialization_version = 1
|
||||
|
||||
def _serialize(udf):
|
||||
dictionary = {}
|
||||
dictionary['class'] = udf.__class__.__name__
|
||||
dictionary['module'] = udf.__class__.__module__
|
||||
dictionary['op_name'] = udf.op_name
|
||||
dictionary['state'] = udf.serialize()
|
||||
dictionary['version'] = _serialization_version
|
||||
return dictionary
|
||||
|
||||
|
||||
class _UDFDeserializeCallbackWrapper(cntk_py.UDFDeserializeCallbackWrapper):
|
||||
'''
|
||||
Provides an implementation of the UDFDeserializer interface used
|
||||
to inflate user defined functions in a model dictionary.
|
||||
'''
|
||||
def __init__(self, factory_callback_map=None):
|
||||
super(_UDFDeserializeCallbackWrapper, self).__init__()
|
||||
self.factory_callback_map = factory_callback_map
|
||||
self.__disown__()
|
||||
|
||||
def __call__(self, inputs, name, dictionary):
|
||||
import pdb; pdb.set_trace()
|
||||
cls = dictionary['class']
|
||||
module = dictionary['module']
|
||||
state = dictionary['state']
|
||||
op_name = dictionary['op_name']
|
||||
deserialize_method = 'deserialize'
|
||||
|
||||
if (self.factory_callback_map and op_name in self.factory_callback_map):
|
||||
factory = self.factory_callback_map[op_name]
|
||||
else:
|
||||
exec("from {} import {}".format(module, cls))
|
||||
eval_str = "{0}.deserialize if hasattr({0}, 'deserialize') else None"
|
||||
factory = eval(eval_str.format(cls))
|
||||
eval_str = "{0}.{1} if hasattr({0}, '{1}') else None"
|
||||
factory = eval(eval_str.format(cls, deserialize_method))
|
||||
|
||||
if (factory):
|
||||
if factory:
|
||||
return factory(list(inputs), name, state)
|
||||
|
||||
raise ValueError("Cannot deserialize user function '{}.{}'. "
|
||||
|
|
|
@ -18,6 +18,15 @@ def is_string(s):
|
|||
'''
|
||||
return isinstance(s, ("".__class__, u"".__class__))
|
||||
|
||||
def is_byte_buffer(s):
|
||||
'''
|
||||
Tests whether ``s`` is a byte buffer (not a string) in a way that
|
||||
works on Python 2 and 3.
|
||||
'''
|
||||
return (isinstance(s, bytearray) or
|
||||
(isinstance(s, type(b'')) and not isinstance(b'', str)))
|
||||
|
||||
|
||||
def _as_tuple(x):
|
||||
'''
|
||||
Convert an argument to a tuple.
|
||||
|
|
|
@ -6,6 +6,8 @@
|
|||
|
||||
from .. import cntk_py
|
||||
import numpy as np
|
||||
from cntk import NDArrayView
|
||||
from ..cntk_py import DictionaryValueFromDict, DictionaryValue, Dictionary, DictionaryValueFromNDArrayView
|
||||
|
||||
_VARIABLE_OR_FUNCTION = (cntk_py.Variable, cntk_py.Function)
|
||||
|
||||
|
@ -206,7 +208,6 @@ def _py_dict_to_cntk_dict(py_dict):
|
|||
cntk_py.Dictionary:
|
||||
A :class:`~cntk.cntk_py.Dictionary` that has been converted from the input `dict`
|
||||
'''
|
||||
from ..cntk_py import DictionaryValueFromDict, DictionaryValue, Dictionary
|
||||
def _to_cntk_dict_value(py_value):
|
||||
if isinstance(py_value, dict):
|
||||
return DictionaryValueFromDict(_py_dict_to_cntk_dict(py_value))
|
||||
|
@ -214,7 +215,14 @@ def _py_dict_to_cntk_dict(py_dict):
|
|||
if isinstance(py_value, list):
|
||||
py_list = list(map(_to_cntk_dict_value, py_value))
|
||||
return DictionaryValue(py_list)
|
||||
|
||||
|
||||
if isinstance(py_value, np.ndarray):
|
||||
py_value = NDArrayView.from_dense(py_value)
|
||||
return DictionaryValueFromNDArrayView(py_value)
|
||||
|
||||
if py_value is None:
|
||||
return DictionaryValue()
|
||||
|
||||
return DictionaryValue(py_value)
|
||||
|
||||
res = Dictionary()
|
||||
|
|
|
@ -13,7 +13,8 @@ from cntk.internal import map_if_possible, typemap, sanitize_var_map,\
|
|||
_value_as_sequence_or_array
|
||||
from cntk.internal.utils import get_python_function_arguments, \
|
||||
map_function_arguments, _py_dict_to_cntk_dict
|
||||
from cntk.internal import _UDFDeserializeCallbackWrapper
|
||||
from cntk.internal import _UDFDeserializeCallbackWrapper, _serialize
|
||||
from cntk.internal.sanitize import is_byte_buffer
|
||||
from ..variables import Record, Variable
|
||||
|
||||
|
||||
|
@ -90,6 +91,10 @@ class Function(cntk_py.Function):
|
|||
in :func:`~cntk.ops.as_block()`, using the supplied ``op_name`` and ``name`` parameters, which are otherwise ignored.
|
||||
'''
|
||||
|
||||
_udf_callback_map = {}
|
||||
_deserializer = _UDFDeserializeCallbackWrapper(_udf_callback_map)
|
||||
cntk_py._register_udf_deserialize_callback(_deserializer)
|
||||
|
||||
# We override the constructors to implement an overload that constructs
|
||||
# a CNTK Functions from a Python function (@Function).
|
||||
def __new__(cls, *args, **kwargs):
|
||||
|
@ -1135,25 +1140,43 @@ class Function(cntk_py.Function):
|
|||
'restore(...) instead', DeprecationWarning)
|
||||
return self.restore(filename)
|
||||
|
||||
@staticmethod
|
||||
def register_udf_deserialize_callback(op_name, callback):
|
||||
'''
|
||||
Register a callback function to be invoked when deserializing a user-
|
||||
defined function with the corresponding op name.
|
||||
|
||||
When loading a model, CNTK will try to automatically reconstruct any
|
||||
(non-native) user-defined functions by invoking a static
|
||||
:func:`~cntk.ops.functions.UserFunction.deserialize` method of the
|
||||
corresponding UserFunction sub-class. This method allows to override
|
||||
default UDF deserialization behavior by specifying a user- defined
|
||||
function op name and the corresponding callback that should be invoked
|
||||
instead of the ``deserialize`` method.
|
||||
|
||||
Args:
|
||||
op_name (str): unique op name of the user-defined function.
|
||||
callback (function): a function taking three arguments (a list of
|
||||
inputs to the UserFunction, a string name, and a state dictionary
|
||||
generated by the corresponding :func:`~cntk.ops.functions.UserFunction.serialize`
|
||||
method) and returns an instance of the user-defined function.
|
||||
'''
|
||||
if op_name in Function._udf_callback_map:
|
||||
raise ValueError("A callback for the UserFunction with op name {}"
|
||||
" has already been registered.".format(op_name));
|
||||
Function._udf_callback_map[op_name] = callback
|
||||
|
||||
@staticmethod
|
||||
@typemap
|
||||
def load(model, device=None, udf_factory_callback_map=None):
|
||||
def load(model, device=None):
|
||||
'''
|
||||
Load the ``model``, that has been saved using :func:`~cntk.ops.functions.Function.save`.
|
||||
|
||||
Args:
|
||||
model (str or bytes): either a filepath of a model file or a byte buffer
|
||||
model (str, bytes or bytearray): either a file path of a model file or a byte buffer
|
||||
containing the binary representation of a model.
|
||||
device (:class:`~cntk.device.DeviceDescriptor`, defaults to the current globally default device):
|
||||
specifies the device to allocate the model on.
|
||||
udf_factory_callback_map (dict, default is `None`): if the model contains any user-defined
|
||||
functions, CNTK will try to automatically reconstruct them by invoking a static
|
||||
``deserialize`` method of the corresponding Function sub-class. This method takes three
|
||||
arguments (a list of inputs to the function, a string name, and a state dictionary
|
||||
generated by the corresponding :func:`~cntk.ops.functions.UserFunction.serialize` method) and
|
||||
returns an instance of the user-defined function. This optional argument allows to override
|
||||
default UDF deserialization behavior by providing a map of user-function op names and
|
||||
corresponding lambdas that should be invoked instead of the ``deserialize`` method.
|
||||
|
||||
Returns:
|
||||
root node
|
||||
|
@ -1161,10 +1184,7 @@ class Function(cntk_py.Function):
|
|||
if not device:
|
||||
device = DeviceDescriptor.use_default_device()
|
||||
|
||||
deserializer = _UDFDeserializeCallbackWrapper(udf_factory_callback_map)
|
||||
|
||||
is_buffer = isinstance(model, type(b'')) and not isinstance(b'', str)
|
||||
is_buffer = is_buffer or isinstance(model, bytearray)
|
||||
is_buffer = is_byte_buffer(model)
|
||||
|
||||
is_file = False
|
||||
if not is_buffer:
|
||||
|
@ -1174,10 +1194,10 @@ class Function(cntk_py.Function):
|
|||
pass
|
||||
|
||||
if is_buffer:
|
||||
return cntk_py.Function.load_from_buffer(model, device, deserializer)
|
||||
return cntk_py.Function.load_from_buffer(model, device)
|
||||
|
||||
if is_file:
|
||||
return cntk_py.Function.load(model, device, deserializer)
|
||||
return cntk_py.Function.load(model, device)
|
||||
|
||||
raise ValueError('Cannot load a model that is neither a file nor a byte buffer.')
|
||||
|
||||
|
@ -1225,11 +1245,11 @@ def native_user_function(op_name, operands, attributes=None, user_function_insta
|
|||
return cntk_py.Function_native_user_function(op_name, operands, attributes, user_function_instance_name)
|
||||
|
||||
@typemap
|
||||
def load_model(model, device=None, udf_factory_callback_map=None):
|
||||
def load_model(model, device=None):
|
||||
'''
|
||||
Alias for :func:`~cntk.ops.functions.Function.load`.
|
||||
'''
|
||||
return Function.load(model, device, udf_factory_callback_map)
|
||||
return Function.load(model, device)
|
||||
|
||||
@typemap
|
||||
def save_model(model, filename): # legacy name
|
||||
|
@ -1251,8 +1271,10 @@ class UserFunction(Function):
|
|||
`False` passes the data as CNTK Value objects.
|
||||
name (str): name of this function
|
||||
'''
|
||||
|
||||
def __init__(self, inputs, as_numpy=True, name=''):
|
||||
super(UserFunction, self).__init__(inputs, name)
|
||||
self.set_native(False)
|
||||
self.as_numpy = as_numpy
|
||||
|
||||
# Since the state will frequently not be used, we cache the None-state
|
||||
|
@ -1402,14 +1424,36 @@ class UserFunction(Function):
|
|||
'''
|
||||
raise NotImplementedError('clone has to be overwritten')
|
||||
|
||||
def _serialize(self):
|
||||
dictionary = {}
|
||||
dictionary['class'] = self.__class__.__name__
|
||||
dictionary['module'] = self.__class__.__module__
|
||||
dictionary['op_name'] = self.op_name
|
||||
dictionary['state'] = self.serialize()
|
||||
def _serialize_impl(self):
|
||||
dictionary = _serialize(self)
|
||||
return _py_dict_to_cntk_dict(dictionary)
|
||||
|
||||
@staticmethod
|
||||
def deserialize(inputs, name, state):
|
||||
'''
|
||||
A stub deserialize method for illustration purposes. User-defined functions
|
||||
need to provide their own implementation in order for CNTK to be able to
|
||||
reconstruct them when loading a model.
|
||||
|
||||
Args:
|
||||
inputs (list): a list of inputs to the function
|
||||
name (str): name of this function
|
||||
state (dict): a state dictionary generated by the corresponding
|
||||
:func:`~cntk.ops.functions.UserFunction.serialize` method.
|
||||
|
||||
Returns:
|
||||
An instance of the user-defined function.
|
||||
'''
|
||||
raise NotImplementedError('a stub method for illustration purposes.')
|
||||
|
||||
@property
|
||||
def op_name(self):
|
||||
'''
|
||||
Unique operation name of this user-defined function.
|
||||
This property defaults to '<module>.<class>', but can be overridden.
|
||||
'''
|
||||
return self.__class__._op_name()
|
||||
|
||||
def serialize(self):
|
||||
'''
|
||||
Generates a dictionary that captures the state of this user-defined function.
|
||||
|
@ -1418,3 +1462,7 @@ class UserFunction(Function):
|
|||
to be preserved in the model dictionary.
|
||||
'''
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def _op_name(cls):
|
||||
return cls.__module__ + '.' + cls.__name__
|
||||
|
|
|
@ -49,10 +49,8 @@ class MyPlus(UserFunction):
|
|||
input_gradients[input] = root_gradients
|
||||
|
||||
def serialize(self):
|
||||
internal_state = {}
|
||||
internal_state['forward_calls'] = self.forward_calls
|
||||
internal_state['backward_calls'] = self.backward_calls
|
||||
return internal_state
|
||||
return {'forward_calls' : self.forward_calls,
|
||||
'backward_calls' : self.backward_calls}
|
||||
|
||||
@staticmethod
|
||||
def deserialize(inputs, name, state):
|
||||
|
@ -61,6 +59,24 @@ class MyPlus(UserFunction):
|
|||
f.backward_calls = state['backward_calls']
|
||||
return f
|
||||
|
||||
|
||||
class MyPlusPlus(MyPlus):
|
||||
def __init__(self, inputs, name, state={}):
|
||||
super(MyPlusPlus, self).__init__(inputs[0], inputs[1], name=name+name)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
r1 = super(MyPlusPlus, self).forward(*args, **kwargs)
|
||||
r2 = super(MyPlusPlus, self).forward(*args, **kwargs)
|
||||
return None, r1[1] + r2[1]
|
||||
|
||||
def serialize(self):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def deserialize(*args):
|
||||
return MyPlusPlus(*args)
|
||||
|
||||
|
||||
def test_ext_eval_1():
|
||||
dim = 4
|
||||
p = parameter(shape=(dim,), init=10, name='p')
|
||||
|
@ -315,12 +331,14 @@ def test_ext_lambdafunc(tmpdir):
|
|||
|
||||
filepath = str(tmpdir / 'test_ext_lambdafunc.dat')
|
||||
z0.save(filepath)
|
||||
z = Function.load(filepath, udf_factory_callback_map =
|
||||
{
|
||||
'conditional_exec_lambda' :
|
||||
lambda x, *unused:
|
||||
LambdaFunc(x, when=lambda arg: np.sum(arg)>1, execute=cb.inc)
|
||||
})
|
||||
|
||||
Function.register_udf_deserialize_callback(
|
||||
'conditional_exec_lambda',
|
||||
lambda x, *unused:
|
||||
LambdaFunc(x, when=lambda arg: np.sum(arg)>1, execute=cb.inc))
|
||||
|
||||
z = Function.load(filepath)
|
||||
|
||||
momentum_time_constant = momentum_as_time_constant_schedule(1100)
|
||||
lr_per_sample = learning_rate_schedule(0.007, UnitType.sample)
|
||||
trainer = Trainer(z, (z+0, z+0), \
|
||||
|
@ -484,18 +502,136 @@ def test_udf_output_needs_no_gradient():
|
|||
assert np.allclose(result, [[[6., 8.], [10., 12.]]])
|
||||
|
||||
|
||||
def test_native_user_function():
|
||||
ops.register_native_user_function('NativeUserTimesFunction', 'Cntk.ExtensibilityExamples-' + C.__version__.rstrip('+'), 'CreateUserTimesFunction')
|
||||
def test_native_user_function(tmpdir):
|
||||
|
||||
if not cntk_py.is_native_user_function_registered('NativeUserTimesOp'):
|
||||
ops.register_native_user_function('NativeUserTimesOp', 'Cntk.ExtensibilityExamples-' + C.__version__.rstrip('+'), 'CreateUserTimesFunction')
|
||||
|
||||
dev = cpu()
|
||||
x = input((2))
|
||||
w = parameter((2, 2), init=np.asarray([[0.5, 2], [-0.5, 1.5]], dtype=np.float32), device=dev)
|
||||
attributes = {'param_rank' : 2, 'padding' : True}
|
||||
op = ops.native_user_function('NativeUserTimesFunction', [w, x], attributes, 'native_user_times_function')
|
||||
attributes = {
|
||||
'param_rank' : 2,
|
||||
'padding' : True,
|
||||
'none': None,
|
||||
'nested lists' : [[1,2,3], [4,5,6]],
|
||||
'string' : 'string',
|
||||
'some data' : np.arange(1, 10, dtype=np.float32).reshape((3,3))
|
||||
}
|
||||
|
||||
def verify_attributes(udf):
|
||||
for k,v in attributes.items():
|
||||
if not isinstance(v, np.ndarray):
|
||||
assert udf.attributes[k] == v
|
||||
else:
|
||||
assert (udf.attributes[k] == v).all()
|
||||
|
||||
op = ops.native_user_function('NativeUserTimesOp', [w, x], attributes, 'native_user_times_function')
|
||||
|
||||
verify_attributes(op.owner)
|
||||
|
||||
filepath = str(tmpdir / 'test_native_user_function.dat')
|
||||
op.save(filepath)
|
||||
|
||||
op_reloaded = Function.load(filepath, device=dev)
|
||||
x_data = NDArrayView.from_dense(np.asarray([[0.1, 0.2], [-0.1, 0.3]], dtype=np.float32), device=dev)
|
||||
result = op.eval({x : x_data}, device=dev)
|
||||
result = op_reloaded.eval({op_reloaded.arguments[0] : x_data}, device=dev)
|
||||
|
||||
assert np.allclose(result, [[-0.05, 0.5], [-0.2, 0.25]])
|
||||
|
||||
native_times_primitive = op.find_by_name('native_user_times_function')
|
||||
assert native_times_primitive.attributes == attributes
|
||||
|
||||
native_times_primitive = op_reloaded.find_by_name('native_user_times_function')
|
||||
|
||||
verify_attributes(native_times_primitive)
|
||||
|
||||
|
||||
def build_test_function():
|
||||
dev = cpu()
|
||||
w_value = np.asarray([[0.5, 2], [-0.5, 1.5]]).astype(np.float32)
|
||||
c1_value = 2.718
|
||||
c2_value = -3.141
|
||||
|
||||
if not cntk_py.is_native_user_function_registered('NativeUserTimesOp'):
|
||||
ops.register_native_user_function('NativeUserTimesOp', 'Cntk.ExtensibilityExamples-' + C.__version__.rstrip('+'), 'CreateUserTimesFunction')
|
||||
|
||||
x = input((2))
|
||||
|
||||
w = parameter((2, 2), init=w_value, device=dev)
|
||||
|
||||
mp = MyPlus(x, constant(c1_value))
|
||||
op = user_function(MyPlus(x, constant(c1_value)))
|
||||
op = ops.native_user_function('NativeUserTimesOp', [w, op], user_function_instance_name='my_times')
|
||||
|
||||
return dev, w_value, c1_value, c2_value, user_function(MyPlus(op, constant(c2_value)))
|
||||
|
||||
def test_both_flavors_of_user_functions(tmpdir):
|
||||
dev, w_value, c1_value, c2_value, op = build_test_function()
|
||||
|
||||
filepath = str(tmpdir / 'test_native_user_function.dat')
|
||||
op.save(filepath)
|
||||
op_reloaded = Function.load(filepath, device=dev)
|
||||
|
||||
np.random.seed(1)
|
||||
|
||||
for i in range(5):
|
||||
x_value = np.random.random((2,2)).astype(np.float32)
|
||||
x_data = NDArrayView.from_dense(x_value, device=dev)
|
||||
result = op_reloaded.eval({op_reloaded.arguments[0] : x_data}, device=dev)
|
||||
expected = np.matmul((x_value + c1_value), w_value) + c2_value
|
||||
assert np.allclose(result, expected)
|
||||
|
||||
def test_udf_checkpointing(tmpdir):
|
||||
dev, w_value, c1_value, c2_value, op = build_test_function()
|
||||
|
||||
label = constant(np.asarray([[1, 2], [3,4]]).astype(np.float32))
|
||||
|
||||
loss = cross_entropy_with_softmax(op, label)
|
||||
eval_error = classification_error(op, label)
|
||||
|
||||
lr_schedule = learning_rate_schedule(0.5, UnitType.minibatch)
|
||||
learner = sgd(op.parameters, lr_schedule)
|
||||
trainer = Trainer(op, (loss, eval_error), [learner])
|
||||
|
||||
trainer.train_minibatch({op.arguments[0] : np.random.random((2,2)).astype(np.float32)}, device=dev)
|
||||
|
||||
filepath = str(tmpdir /'test_checkpointing.out')
|
||||
|
||||
trainer.save_checkpoint(filepath, external_state={'test':'test'})
|
||||
|
||||
d = cntk_py.Dictionary.load(filepath)
|
||||
assert len(d.keys()) != 0
|
||||
|
||||
def test_override_deserialize(tmpdir):
|
||||
dev, w_value, c1_value, c2_value, op = build_test_function()
|
||||
|
||||
filepath = str(tmpdir / 'test_override_deserialize.dat')
|
||||
op.save(filepath)
|
||||
|
||||
Function.register_udf_deserialize_callback(
|
||||
MyPlus._op_name(), lambda *x : MyPlusPlus(*x))
|
||||
|
||||
op_reloaded = Function.load(filepath, device=dev)
|
||||
|
||||
np.random.seed(1)
|
||||
|
||||
for i in range(5):
|
||||
x_value = np.random.random((2,2)).astype(np.float32)
|
||||
x_data = NDArrayView.from_dense(x_value, device=dev)
|
||||
result = op_reloaded.eval({op_reloaded.arguments[0] : x_data}, device=dev)
|
||||
expected = 2 *(np.matmul(2 * (x_value + c1_value), w_value) + c2_value)
|
||||
assert np.allclose(result, expected)
|
||||
|
||||
def test_override_serialize(tmpdir):
|
||||
dev=cpu()
|
||||
a,b = 1.2322341, -0.29084;
|
||||
op = MyPlusPlus([constant(a), constant(b)], '++');
|
||||
op = MyPlusPlus([op, op], '+++')
|
||||
op = MyPlusPlus([op, op], '++++')
|
||||
op = user_function(op)
|
||||
result1 = op.eval({}, device=dev)
|
||||
|
||||
filepath = str(tmpdir / 'test_udf_with_renamed_deserialize.dat')
|
||||
op.save(filepath)
|
||||
|
||||
op_reloaded = Function.load(filepath, device=dev)
|
||||
|
||||
assert result1 == op_reloaded.eval({}, device=dev)
|
|
@ -14,6 +14,7 @@ Implementing a custom operator in pure Python is simple matter of
|
|||
- implementing ``forward()`` and ``backward()``, whose signatures dependent on the number of inputs and outputs
|
||||
- specifying the outputs' shape, data type and dynamic axes in
|
||||
``infer_outputs()``
|
||||
- providing a static ``deserialize()`` method to inflate previously saved function
|
||||
|
||||
In the simplest case, just only one input and output, ``forward()`` takes an
|
||||
argument and returns a tuple of a state and the result. The state can be used to
|
||||
|
@ -46,6 +47,10 @@ tuple, strings, etc.)::
|
|||
return [output_variable(self.inputs[0].shape, self.inputs[0].dtype,
|
||||
self.inputs[0].dynamic_axes)]
|
||||
|
||||
@staticmethod
|
||||
def deserialize(inputs, name, state):
|
||||
return = MySigmoid(inputs[0], name)
|
||||
|
||||
This can now be used as a normal operator like::
|
||||
|
||||
from cntk import user_function
|
||||
|
@ -61,13 +66,17 @@ In case, the operator is initialized with multiple inputs, ``forward()`` 's
|
|||
class MyPlus(UserFunction):
|
||||
def __init__(self, arg1, arg2, name='f1'):
|
||||
super(MyPlus, self).__init__([arg1, arg2], name=name)
|
||||
self.forward_calls = 0
|
||||
self.backward_calls = 0
|
||||
|
||||
def forward(self, arguments, device=None, outputs_to_retain=None):
|
||||
# No state needs to be passed to backward() so we just
|
||||
# pass None
|
||||
self.forward_calls += 1
|
||||
return None, arguments[0] + arguments[1]
|
||||
|
||||
def backward(self, state, root_gradients):
|
||||
self.backward_calls += 1
|
||||
return root_gradients
|
||||
|
||||
def infer_outputs(self):
|
||||
|
@ -76,6 +85,17 @@ In case, the operator is initialized with multiple inputs, ``forward()`` 's
|
|||
# result would actually look like (considering broadcasting, etc.).
|
||||
return [output_variable(self.inputs[0].shape, self.inputs[0].dtype, self.inputs[0].dynamic_axes)]
|
||||
|
||||
def serialize(self):
|
||||
return {'forward_calls' : self.forward_calls,
|
||||
'backward_calls' : self.backward_calls}
|
||||
|
||||
@staticmethod
|
||||
def deserialize(inputs, name, state):
|
||||
f = MyPlus(inputs[0], inputs[1], name)
|
||||
f.forward_calls = state['forward_calls']
|
||||
f.backward_calls = state['backward_calls']
|
||||
return f
|
||||
|
||||
If the UserFunction has more than one input, ``backward()`` is invoked
|
||||
with an additional ``variables`` argument, which contains a dictionary of
|
||||
Variable to the gradient data, whose values have to be set with the proper
|
||||
|
@ -101,6 +121,13 @@ have to be set with the proper data.
|
|||
In addition, ``root_gradient`` in ``backward()`` is a dictionary of Variable to the
|
||||
root_gradient.
|
||||
|
||||
``deserialize()`` is invoked by CNTK to reconstruct a previously saved function. It should
|
||||
have the same signature as :func:`~cntk.ops.functions.UserFunction.deserialize` method.
|
||||
In case of a stateless function, it simply needs to invoke the constructor and return an
|
||||
instance of the user function. However, if the function is stateful and overrides
|
||||
:func:`~cntk.ops.functions.UserFunction.serialize` method, ``deserialize()`` also needs to
|
||||
properly restore the function state.
|
||||
|
||||
Using user functions for debugging
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче