diff --git a/Examples/Extensibility/CPP/Main.cpp b/Examples/Extensibility/CPP/Main.cpp index d7911b070..da12e4549 100644 --- a/Examples/Extensibility/CPP/Main.cpp +++ b/Examples/Extensibility/CPP/Main.cpp @@ -16,41 +16,71 @@ void UserTimesFunctionExample() auto x = InputVariable(NDShape({ inDim }), DataType::Float, { Axis::DefaultBatchAxis() }); auto userDefinedTimes = UserTimesFunction::Create(W, x, L"UserDefinedTimes"); - size_t batchSize = 3; - std::vector inputData(inDim * batchSize); - for (size_t i = 0; i < inputData.size(); ++i) - inputData[i] = (float)rand() / RAND_MAX; - auto inputDataValue = Value::CreateBatch(x.Shape(), inputData, device); + auto compareWithBuiltInTimes = [device, outDim, inDim](FunctionPtr times) { + size_t batchSize = 3; + std::vector inputData(inDim * batchSize); + for (size_t i = 0; i < inputData.size(); ++i) + inputData[i] = (float)rand() / RAND_MAX; - std::vector rootGradientData(outDim * batchSize, 1); - auto rootGradientValue = Value::CreateBatch(userDefinedTimes->Output().Shape(), rootGradientData, device); + auto input = times->Arguments()[0]; + auto inputDataValue = Value::CreateBatch(input.Shape(), inputData, device); - std::unordered_map outputValues = { { userDefinedTimes->Output(), nullptr } }; - auto backPropState = userDefinedTimes->Forward({ { x, inputDataValue } }, outputValues, device, { userDefinedTimes->Output() }); + std::vector rootGradientData(outDim * batchSize, 1); + auto rootGradientValue = Value::CreateBatch(times->Output().Shape(), rootGradientData, device); - std::unordered_map inputGradientValues = { { W, nullptr } }; - userDefinedTimes->Backward(backPropState, { { userDefinedTimes->Output(), rootGradientValue } }, inputGradientValues); - auto userDefinedTimesOutputValue = outputValues[userDefinedTimes->Output()]; - auto userDefinedTimesInputGradientValue = inputGradientValues[W]; + std::unordered_map outputValues = { { times->Output(), nullptr } }; + auto backPropState = times->Forward({ { input, 
inputDataValue } }, outputValues, device, { times->Output() }); - // Compare against the CNTK built-in implementation - auto builtInTimes = Times(W, x, L"BuiltInTimes"); - outputValues = { { builtInTimes->Output(), nullptr } }; - backPropState = builtInTimes->Forward({ { x, inputDataValue } }, outputValues, device, { builtInTimes->Output() }); - inputGradientValues = { { W, nullptr } }; - builtInTimes->Backward(backPropState, { { builtInTimes->Output(), rootGradientValue } }, inputGradientValues); - auto builtInTimesOutputValue = outputValues[builtInTimes->Output()]; - auto builtInTimesInputGradientValue = inputGradientValues[W]; - const double relativeTolerance = 0.001f; - const double absoluteTolerance = 0.000001f; + auto parameter = times->Parameters()[0]; - if (!Internal::AreEqual(*userDefinedTimesOutputValue, *builtInTimesOutputValue, relativeTolerance, absoluteTolerance)) - std::runtime_error("UserTimesOp's Forward result does not match built-in result"); + std::unordered_map inputGradientValues = { { parameter, nullptr } }; + times->Backward(backPropState, { { times->Output(), rootGradientValue } }, inputGradientValues); + auto userDefinedTimesOutputValue = outputValues[times->Output()]; + auto userDefinedTimesInputGradientValue = inputGradientValues[parameter]; - if (!Internal::AreEqual(*userDefinedTimesInputGradientValue, *builtInTimesInputGradientValue, relativeTolerance, absoluteTolerance)) - std::runtime_error("UserTimesOp's Forward result does not match built-in result"); + // Compare against the CNTK built-in implementation + auto builtInTimes = Times(parameter, input, L"BuiltInTimes"); + outputValues = { { builtInTimes->Output(), nullptr } }; + backPropState = builtInTimes->Forward({ { input, inputDataValue } }, outputValues, device, { builtInTimes->Output() }); + inputGradientValues = { { parameter, nullptr } }; + builtInTimes->Backward(backPropState, { { builtInTimes->Output(), rootGradientValue } }, inputGradientValues); + auto 
builtInTimesOutputValue = outputValues[builtInTimes->Output()]; + auto builtInTimesInputGradientValue = inputGradientValues[parameter]; + + const double relativeTolerance = 0.001f; + const double absoluteTolerance = 0.000001f; + + if (!Internal::AreEqual(*userDefinedTimesOutputValue, *builtInTimesOutputValue, relativeTolerance, absoluteTolerance)) + throw std::runtime_error("UserTimesOp's Forward result does not match built-in result"); + + if (!Internal::AreEqual(*userDefinedTimesInputGradientValue, *builtInTimesInputGradientValue, relativeTolerance, absoluteTolerance)) + throw std::runtime_error("UserTimesOp's Backward result does not match built-in result"); + + }; + + compareWithBuiltInTimes(userDefinedTimes); + + auto version = std::string(CNTK_COMPONENT_VERSION); + std::wstring wversion(version.begin(), version.end()); + Function::RegisterNativeUserFunction(L"NativeUserTimesOp", L"Cntk.ExtensibilityExamples-" + wversion, L"CreateUserTimesFunction"); + + userDefinedTimes->Save(L"UserTimesFunctionExample.model"); + + auto userDefinedTimes_reloaded_1 = Function::Load(L"UserTimesFunctionExample.model", device); + + compareWithBuiltInTimes(userDefinedTimes_reloaded_1); + + Function::RegisterUDFDeserializeCallback(L"NativeUserTimesOp", [](const std::vector& inputs, + const std::wstring& name, + const Dictionary& state) { + return UserTimesFunction::Create(inputs[0], inputs[1], state, name); + }); + + auto userDefinedTimes_reloaded_2 = Function::Load(L"UserTimesFunctionExample.model", device); + + compareWithBuiltInTimes(userDefinedTimes_reloaded_2); } #pragma warning(pop) diff --git a/Examples/Extensibility/CPP/UserMatrixMultiplicationOp.h b/Examples/Extensibility/CPP/UserMatrixMultiplicationOp.h index 1efb85fb3..488356938 100644 --- a/Examples/Extensibility/CPP/UserMatrixMultiplicationOp.h +++ b/Examples/Extensibility/CPP/UserMatrixMultiplicationOp.h @@ -21,7 +21,7 @@ public: } UserTimesFunction(const Variable& leftOperand, const Variable& rightOperand, const Dictionary& 
attributes, const std::wstring& name) - : Function({ leftOperand, rightOperand }, Dictionary(attributes), name) + : Function({ leftOperand, rightOperand }, attributes, name) {} private: @@ -116,11 +116,10 @@ private: const std::wstring& OpName() const override { - static const std::wstring opName = L"UserTimesOp"; + static const std::wstring opName = L"NativeUserTimesOp"; return opName; } - Dictionary Serialize() const override { NOT_IMPLEMENTED; } size_t CurrentVersion() const override { NOT_IMPLEMENTED; } void InferOutputs(std::vector& outputs) override diff --git a/Makefile b/Makefile index c456a895c..ff2f384ff 100644 --- a/Makefile +++ b/Makefile @@ -469,6 +469,7 @@ CNTKLIBRARY_COMMON_SRC =\ $(SOURCEDIR)/CNTKv2LibraryDll/Function.cpp \ $(SOURCEDIR)/CNTKv2LibraryDll/PrimitiveFunction.cpp \ $(SOURCEDIR)/CNTKv2LibraryDll/CompositeFunction.cpp \ + $(SOURCEDIR)/CNTKv2LibraryDll/UserDefinedFunction.cpp \ $(SOURCEDIR)/CNTKv2LibraryDll/NDArrayView.cpp \ $(SOURCEDIR)/CNTKv2LibraryDll/NDMask.cpp \ $(SOURCEDIR)/CNTKv2LibraryDll/Trainer.cpp \ diff --git a/Source/CNTKv2LibraryDll/API/CNTKLibrary.h b/Source/CNTKv2LibraryDll/API/CNTKLibrary.h index 70450f592..c6503b66e 100644 --- a/Source/CNTKv2LibraryDll/API/CNTKLibrary.h +++ b/Source/CNTKv2LibraryDll/API/CNTKLibrary.h @@ -2885,6 +2885,8 @@ namespace CNTK const std::wstring& /*name*/, const Dictionary& /*dictionary*/)> UDFDeserializeCallback; + typedef std::shared_ptr UDFDeserializeCallbackPtr; + static auto NoOp = [] (const std::vector&, const std::wstring&, const Dictionary&) { return nullptr; @@ -2907,6 +2909,7 @@ namespace CNTK friend class Trainer; friend Variable GetCorrespondingOutputVariableFromClone(const Variable&, const FunctionPtr&, const FunctionPtr&); + friend bool Internal::IsNativeUserFunctionRegistered(const std::wstring& uniqueOpName); public: @@ -2975,6 +2978,20 @@ namespace CNTK /// CNTK_API virtual void InferOutputs(std::vector& outputs) = 0; + /// + /// Returns the name of the module (dll/so) containing 
this function. For native functions registered through + /// a call to 'RegisterNativeUserFunction', unless overridden, this method returns the value of the 'moduleName' + /// argument. + /// + CNTK_API virtual std::wstring ModuleName() const; + + /// + /// Returns the name of the method which should be invoked to deserialize this function. For native functions + /// registered through a call to 'RegisterNativeUserFunction', unless overridden, this method returns the value + /// of the 'factoryMethodName' argument. If overridden, it must have the same signature as the factory method. + /// + CNTK_API virtual std::wstring DeserializeMethodName() const; + public: // Optional overrides @@ -2997,7 +3014,7 @@ namespace CNTK /// /// Generates a dictionary that captures the state of the Function graph underlying this Function. /// - CNTK_API virtual Dictionary Serialize() const override { return Dictionary(); } + CNTK_API virtual Dictionary Serialize() const override { return Attributes(); } /// /// Creates a clone of this Function instance, using the specified 'inputs' that are inputs of the clone to be constructed. @@ -3052,8 +3069,7 @@ namespace CNTK /// if deserializer was omitted). If there are no user defined functions in the model, deserializer is ignored. 
/// CNTK_API static FunctionPtr Deserialize(const Dictionary& dictionary, - const ::CNTK::DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice(), - const UDFDeserializeCallback& deserializer = NoOp); + const ::CNTK::DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice()); public: /// @@ -3250,44 +3266,19 @@ namespace CNTK /// Load a Function from a model file /// CNTK_API static FunctionPtr Load(const std::wstring& filepath, - const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice(), - const UDFDeserializeCallback& deserializer = NoOp); + const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice()); /// /// Load a Function from a memory buffer /// CNTK_API static FunctionPtr Load(const char* buffer, size_t length, - const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice(), - const UDFDeserializeCallback& deserializer = NoOp); + const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice()); /// /// Load a Function from an istream. The legacy V1 model is not supported. 
/// CNTK_API static FunctionPtr Load(std::istream& inputStream, - const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice(), - const UDFDeserializeCallback& deserializer = NoOp); - -#ifdef SWIG // for Python interop (adds callback wrapper) - static CNTK::FunctionPtr Load(const std::wstring& filepath, - const CNTK::DeviceDescriptor& computeDevice, - const CNTK::Internal::UDFDeserializeCallbackWrapper& wrapper) - { - using namespace std::placeholders; - UDFDeserializeCallback callback = std::bind(&CNTK::Internal::UDFDeserializeCallbackWrapper::operator(), - &wrapper, _1, _2, _3); - return CNTK::Function::Load(filepath, computeDevice, callback); - } - - static CNTK::FunctionPtr Load(const char* buffer, size_t length, - const CNTK::DeviceDescriptor& computeDevice, - const CNTK::Internal::UDFDeserializeCallbackWrapper& wrapper) - { - using namespace std::placeholders; - UDFDeserializeCallback callback = std::bind(&CNTK::Internal::UDFDeserializeCallbackWrapper::operator(), - &wrapper, _1, _2, _3); - return CNTK::Function::Load(buffer, length, computeDevice, callback); - } -#endif + const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice()); /// /// Prints the entire graph underlying this Function to stderr @@ -3317,6 +3308,18 @@ namespace CNTK /// CNTK_API static FunctionPtr NativeUserFunction(const std::wstring& opName, const std::vector& operands, const Dictionary& functionConfig, const std::wstring& userFunctionInstanceName = L""); + /// + /// Register a callback function to be invoked when deserializing a user-defined Function with the corresponding op name. + /// When loading a model, CNTK will try to automatically reconstruct user-defined Functions (for native functions, CNTK will + /// invoke the same factory method, the Function op name was registered with). 
This method allows to override + /// default user-defined Function deserialization behavior by specifying an op name and the corresponding callback that should be invoked + /// to inflate the Function object. + /// + CNTK_API static void RegisterUDFDeserializeCallback(const std::wstring& uniqueOpName, const UDFDeserializeCallback& deserializer); + + static UDFDeserializeCallbackPtr GetUDFDeserializeCallback(const std::wstring& uniqueOpName); + + protected: static bool IsArgument(const Variable& var) { @@ -3324,9 +3327,10 @@ namespace CNTK } /// - /// Protected constructor for derived 'Function' types to specify the actual input and output variables for the (primitive) Function instance. + /// Protected constructors for derived user-defined 'Function' types to specify the actual input and output variables for the (primitive) Function instance. /// - CNTK_API Function(const std::vector& inputs, Dictionary&& functionConfig, const std::wstring& name = L"", const std::wstring& uid = Internal::GenerateUid(L"UserDefinedFunction")); + CNTK_API Function(const std::vector& inputs, const Dictionary& functionConfig, const std::wstring& name = L""); + CNTK_API Function(const std::vector& inputs, const std::wstring& name = L""); template static void PreorderTraverseFunctions(const FunctionPtr& rootFunction, const FunctionType& functor, bool traverseInsideBlockFunction = false) @@ -3420,14 +3424,11 @@ namespace CNTK // Disallow copy and move construction and assignment Function(const Function&) = delete; Function(Function&&) = delete; Function& operator=(const Function&) = delete; Function& operator=(Function&&) = delete; - public: - CNTK_API Function(const std::vector& inputs, const std::wstring& name = L""); - private: static UserFunctionFactoryPtr s_userFunctionFactory; private: - CNTK_API Function(const std::vector& inputs, Dictionary&& functionConfig, const FunctionPtr& rootFunction, const std::wstring& name, const std::wstring& uid); + Function(const std::vector& inputs, 
const Dictionary& functionConfig, const FunctionPtr& rootFunction, const std::wstring& name, const std::wstring& uid); std::vector m_inputs; std::once_flag m_outputsInitFlag; @@ -3437,6 +3438,22 @@ namespace CNTK std::wstring m_name; std::wstring m_uid; Dictionary m_attributes; + +#ifdef SWIG + public: + void SetNative(bool native) { m_native = native; } +#endif + + private: + bool IsNative() const { return m_native; } + + bool m_native = true; + + Dictionary SerializeNativeImpl() const; + + static FunctionPtr DeserializeNativeImpl(const std::vector& inputs, const std::wstring& name, const Dictionary& dict); + + static const size_t s_serializationVersion = 1; }; /// diff --git a/Source/CNTKv2LibraryDll/API/CNTKLibraryInternals.h b/Source/CNTKv2LibraryDll/API/CNTKLibraryInternals.h index b9e285572..3c66a6a78 100644 --- a/Source/CNTKv2LibraryDll/API/CNTKLibraryInternals.h +++ b/Source/CNTKv2LibraryDll/API/CNTKLibraryInternals.h @@ -376,8 +376,6 @@ namespace CNTK std::wstring m_fileName; }; - -#ifdef SWIG // SWIG callback wrapper for the UDF deserialization. class UDFDeserializeCallbackWrapper { @@ -385,8 +383,13 @@ namespace CNTK virtual FunctionPtr operator()(const std::vector&, const std::wstring&, const Dictionary&) const = 0; virtual ~UDFDeserializeCallbackWrapper() = default; }; -#endif + typedef std::shared_ptr UDFDeserializeCallbackWrapperPtr; + + CNTK_API void RegisterUDFDeserializeCallbackWrapper(UDFDeserializeCallbackWrapperPtr callbackPtr); + + + CNTK_API bool IsNativeUserFunctionRegistered(const std::wstring& uniqueOpName); } // Forward-declare test fixtures, so that they can be used as friends. 
diff --git a/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj b/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj index 1ee5886ec..59560f34d 100644 --- a/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj +++ b/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj @@ -137,6 +137,7 @@ + @@ -175,6 +176,7 @@ + diff --git a/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj.filters b/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj.filters index 7eec2a5c1..990693439 100644 --- a/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj.filters +++ b/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj.filters @@ -36,6 +36,7 @@ + @@ -67,6 +68,7 @@ + diff --git a/Source/CNTKv2LibraryDll/CompositeFunction.cpp b/Source/CNTKv2LibraryDll/CompositeFunction.cpp index 241abbe80..b04ff4d93 100755 --- a/Source/CNTKv2LibraryDll/CompositeFunction.cpp +++ b/Source/CNTKv2LibraryDll/CompositeFunction.cpp @@ -23,6 +23,7 @@ #include "BlockFunction.h" #include "SpecialPurposeNodes.h" #include "SequenceReshapeNodes.h" +#include "UserDefinedFunction.h" using namespace Microsoft::MSR::CNTK; @@ -58,7 +59,7 @@ namespace CNTK for (auto& function : m_allPrimitiveFunctions) { auto primitiveFunction = dynamic_cast(function.get()); - if (!primitiveFunction->IsStateful()) + if (!primitiveFunction || !primitiveFunction->IsStateful()) continue; // TODO: same for BatchNorm @@ -85,7 +86,7 @@ namespace CNTK for (auto& function : m_allPrimitiveFunctions) { auto primitiveFunction = dynamic_cast(function.get()); - if (!primitiveFunction->IsStateful()) + if (!primitiveFunction || !primitiveFunction->IsStateful()) continue; // TODO: same for BatchNorm @@ -224,7 +225,7 @@ namespace CNTK return composite; } - /*static*/ FunctionPtr CompositeFunction::Deserialize(const Dictionary& dict, const CNTK::DeviceDescriptor& device, const UDFDeserializeCallback& callback) + /*static*/ FunctionPtr CompositeFunction::Deserialize(const Dictionary& dict, const CNTK::DeviceDescriptor& device) { static const vector s_requiredDictionaryKeys = { 
inputsKey, functionsKey }; @@ -254,7 +255,7 @@ namespace CNTK { auto functionDict = dictionaryValue.Value(); FunctionPtr root = UDFUtils::IsUDF(functionDict) ? - UDFUtils::Deserialize(functionDict, uidToInputMap, device, callback) : + UDFUtils::Deserialize(functionDict, uidToInputMap, device) : PrimitiveFunction::Deserialize(functionDict, uidToInputMap, allPrimitiveFunctions, allPlaceholderReplacements, device); allPrimitiveFunctions.insert(root); @@ -303,7 +304,7 @@ namespace CNTK for (auto& function : functions) { auto primitiveFunction = dynamic_cast(function.get()); - if (!primitiveFunction->IsStateful()) + if (!primitiveFunction || !primitiveFunction->IsStateful()) continue; if (stateDictionary.Contains(primitiveFunction->Uid())) @@ -340,7 +341,7 @@ namespace CNTK vector uids; PreorderTraverseFunctions(function.RootFunction(), [&uids](const FunctionPtr& funcPtr) { auto primitiveFunction = dynamic_cast(funcPtr.get()); - if (primitiveFunction->IsStateful()) + if (primitiveFunction && primitiveFunction->IsStateful()) { uids.push_back(funcPtr->Uid()); } @@ -384,7 +385,7 @@ namespace CNTK for (const auto& function : m_allPrimitiveFunctions) { auto primitiveFunction = dynamic_cast(function.get()); - if (!primitiveFunction->IsStateful()) + if (!primitiveFunction || !primitiveFunction->IsStateful()) continue; auto functionState = state[primitiveFunction->Uid()].Value(); diff --git a/Source/CNTKv2LibraryDll/CompositeFunction.h b/Source/CNTKv2LibraryDll/CompositeFunction.h index 167352c87..a100f0343 100644 --- a/Source/CNTKv2LibraryDll/CompositeFunction.h +++ b/Source/CNTKv2LibraryDll/CompositeFunction.h @@ -119,7 +119,7 @@ namespace CNTK const std::unordered_map& allPlaceholderReplacements, const CNTK::DeviceDescriptor& device); - static FunctionPtr Deserialize(const Dictionary& dictionary, const CNTK::DeviceDescriptor& device, const UDFDeserializeCallback& callback); + static FunctionPtr Deserialize(const Dictionary& dictionary, const CNTK::DeviceDescriptor& device); 
virtual const std::wstring& OpName() const override { diff --git a/Source/CNTKv2LibraryDll/Function.cpp b/Source/CNTKv2LibraryDll/Function.cpp index 5c05c401f..951d42670 100755 --- a/Source/CNTKv2LibraryDll/Function.cpp +++ b/Source/CNTKv2LibraryDll/Function.cpp @@ -27,6 +27,30 @@ namespace CNTK return AsComposite(s_userFunctionFactory->CreateInstance(opName, operands, functionConfig, userFunctionInstanceName), userFunctionInstanceName); } + bool Internal::IsNativeUserFunctionRegistered(const std::wstring& uniqueOpName) + { + return Function::s_userFunctionFactory->IsRegistered(uniqueOpName); + } + + static std::unordered_map udfCallbackMap; + static std::mutex udfCallbackMapMutex; + + /*static*/ void Function::RegisterUDFDeserializeCallback(const std::wstring& uniqueOpName, const UDFDeserializeCallback& deserializer) + { + std::unique_lock lock(udfCallbackMapMutex); + auto result = udfCallbackMap.insert({ uniqueOpName, make_shared(deserializer) }); + if (!result.second) + InvalidArgument("A callback for the UserFunction with op name '%S' has already been registered.", uniqueOpName.c_str()); + } + + /*static*/ UDFDeserializeCallbackPtr Function::GetUDFDeserializeCallback(const std::wstring& uniqueOpName) + { + std::unique_lock lock(udfCallbackMapMutex); + if (udfCallbackMap.find(uniqueOpName) == udfCallbackMap.end()) + return nullptr; + return udfCallbackMap.at(uniqueOpName); + } + std::vector& Function::InitOutputs() { std::call_once(m_outputsInitFlag, [this]() { @@ -85,10 +109,6 @@ namespace CNTK return std::shared_ptr>(new std::vector(std::move(inputs)), [](std::vector* ptr) { delete ptr; }); } - Function::Function(const std::vector& inputs, Dictionary&& functionConfig, const std::wstring& name, const std::wstring& uid) - : Function(inputs, std::move(functionConfig), nullptr, name, uid) - {} - std::shared_ptr> Function::OutputsImpl() const { std::vector outputs; @@ -99,11 +119,7 @@ namespace CNTK return std::shared_ptr>(new std::vector(std::move(outputs)), 
[](std::vector* ptr) { delete ptr; }); } - Function::Function(const std::vector& inputs, const std::wstring& name) - : Function(inputs, Dictionary(), name) - {} - - Function::Function(const std::vector& inputs, Dictionary&& functionConfig, const FunctionPtr& rootFunction, const std::wstring& name, const std::wstring& uid) + Function::Function(const std::vector& inputs, const Dictionary& functionConfig, const FunctionPtr& rootFunction, const std::wstring& name, const std::wstring& uid) : m_rootFunction(rootFunction), m_name(name), m_uid(uid), m_attributes(std::move(functionConfig)) { for (auto inputVar : inputs) @@ -438,14 +454,14 @@ namespace CNTK stream->flush(); } - /*static*/ FunctionPtr Function::Load(const std::wstring& filepath, const DeviceDescriptor& computeDevice, const UDFDeserializeCallback& callback) + /*static*/ FunctionPtr Function::Load(const std::wstring& filepath, const DeviceDescriptor& computeDevice) { auto stream = GetFstream(filepath, true); if (!Internal::IsLegacyModel(*stream)) { Dictionary model; *stream >> model; - return Function::Deserialize(model, computeDevice, callback); + return Function::Deserialize(model, computeDevice); } else { @@ -453,7 +469,7 @@ namespace CNTK } } - /*static*/ FunctionPtr Function::Load(const char *buffer, size_t length, const DeviceDescriptor& computeDevice, const UDFDeserializeCallback& callback) + /*static*/ FunctionPtr Function::Load(const char *buffer, size_t length, const DeviceDescriptor& computeDevice) { if ((buffer == nullptr) || (length <= 0)) InvalidArgument("The model buffer should not be null and its length should be greater than 0"); @@ -474,15 +490,15 @@ namespace CNTK modelStreamBuffer buf(buffer, length); std::istream modelStream(&buf); - return Load(modelStream, computeDevice, callback); + return Load(modelStream, computeDevice); } } - /*static*/ FunctionPtr Function::Load(std::istream& inputStream, const DeviceDescriptor& computeDevice, const UDFDeserializeCallback& callback) + /*static*/ 
FunctionPtr Function::Load(std::istream& inputStream, const DeviceDescriptor& computeDevice) { Dictionary model; inputStream >> model; - return Function::Deserialize(model, computeDevice, callback); + return Function::Deserialize(model, computeDevice); } void Function::Restore(const std::wstring& filepath) @@ -904,9 +920,9 @@ namespace CNTK compositeFunction->CopyState(*restoredCompositeFunction); } - /*static*/ FunctionPtr Function::Deserialize(const Dictionary& modelDictionary, const CNTK::DeviceDescriptor& device, const UDFDeserializeCallback& callback) + /*static*/ FunctionPtr Function::Deserialize(const Dictionary& modelDictionary, const CNTK::DeviceDescriptor& device) { - return CompositeFunction::Deserialize(modelDictionary, device, callback); + return CompositeFunction::Deserialize(modelDictionary, device); } void Function::PrintGraph() const diff --git a/Source/CNTKv2LibraryDll/PrimitiveFunction.cpp b/Source/CNTKv2LibraryDll/PrimitiveFunction.cpp index 8fabd6282..f1bef6cff 100644 --- a/Source/CNTKv2LibraryDll/PrimitiveFunction.cpp +++ b/Source/CNTKv2LibraryDll/PrimitiveFunction.cpp @@ -21,6 +21,7 @@ #include "ConvolveGeometry.h" #include "ConvolutionalNodes.h" #include "Variable.h" +#include "UserFunctionFactory.h" using namespace Microsoft::MSR::CNTK; @@ -963,7 +964,7 @@ namespace CNTK } } - static vector GetInputUids(const Function& f) + vector GetInputUids(const Function& f) { auto inputs = f.Inputs(); vector inputUids; @@ -975,7 +976,7 @@ namespace CNTK return inputUids; } - static Dictionary SerializeCommonAttributes(const Function& f, size_t version, const wstring& functionType) + Dictionary SerializeCommonFunctionAttributes(const Function& f, size_t version, const wstring& functionType) { Dictionary dict; dict[versionKey] = version; @@ -991,7 +992,7 @@ namespace CNTK /*virtual*/ Dictionary PrimitiveFunction::Serialize() const { - Dictionary dict = SerializeCommonAttributes(*this, CurrentVersion(), s_primitiveFunctionTypeValue); + Dictionary dict = 
SerializeCommonFunctionAttributes(*this, CurrentVersion(), s_primitiveFunctionTypeValue); dict[opKey] = static_cast(m_op); dict[attributesKey] = Attributes(); @@ -1018,7 +1019,7 @@ namespace CNTK return dict; } - static std::vector GetInputVariables(const Dictionary& dict, const unordered_map& uidToVariableMap, size_t currentSerializationVersion) + std::vector GetInputVariables(const Dictionary& dict, const std::unordered_map& uidToVariableMap, size_t currentSerializationVersion) { const auto& inputUids = dict[inputsKey].Value>(); @@ -1246,54 +1247,4 @@ namespace CNTK return dummyOutputVariable.Shape(); } - - static const std::wstring s_userDefinedFunctionTypeValue = L"UserDefinedFunction"; - - /*static*/ bool UDFUtils::IsUDF(const FunctionPtr& f) - { - return (dynamic_cast(f.get()) == nullptr); - } - - /*static*/ bool UDFUtils::IsUDF(const Dictionary& dict) - { - return (dict.Contains(typeKey) && dict[typeKey].Value() == s_userDefinedFunctionTypeValue); - } - - /*static*/ Dictionary UDFUtils::Serialize(const FunctionPtr& udf) - { - Dictionary dict = SerializeCommonAttributes(*udf, s_serializationVersion, s_userDefinedFunctionTypeValue); - dict[userDefinedStateKey] = udf->Serialize(); - return dict; - } - - /*static*/ FunctionPtr UDFUtils::Deserialize(const Dictionary& dict, - const unordered_map& uidToVariableMap, - const DeviceDescriptor& device, - const UDFDeserializeCallback& callback) - { - static const vector s_requiredDictionaryKeys = { typeKey, uidKey, inputsKey, userDefinedStateKey }; - ValidateDictionary(dict, s_requiredDictionaryKeys, s_userDefinedFunctionTypeValue, s_serializationVersion); - - const auto& uid = dict[uidKey].Value(); - std::wstring name = L""; - if (dict.Contains(nameKey)) - name = dict[nameKey].Value(); - - auto inputs = GetInputVariables(dict, uidToVariableMap, s_serializationVersion); - - auto state = dict[userDefinedStateKey].Value(); - - auto udf = callback(inputs, name, state); - - if (udf == nullptr) - { - RuntimeError("Unable to 
reconstruct a user-defined function. Please make sure to specify a valid UDF deserializer."); - } - - // Restore the original uid, which other functions in the graph depend on - // (their inputs refer to the uids of this UDF outputs, which are generated base on the uid of this UDF). - udf->m_uid = uid; - - return udf; - } } diff --git a/Source/CNTKv2LibraryDll/PrimitiveFunction.h b/Source/CNTKv2LibraryDll/PrimitiveFunction.h index 713f1a9eb..e3a7b5e4a 100644 --- a/Source/CNTKv2LibraryDll/PrimitiveFunction.h +++ b/Source/CNTKv2LibraryDll/PrimitiveFunction.h @@ -263,7 +263,7 @@ namespace CNTK protected: PrimitiveFunction(PrimitiveOpType op, const std::vector& inputs, Dictionary&& functionConfig, const std::wstring& functionName, const std::wstring& uid) - : Function(inputs, std::move(functionConfig), functionName, uid), m_op(op) + : Function(inputs, std::move(functionConfig), nullptr, functionName, uid), m_op(op) {} public: @@ -759,21 +759,7 @@ namespace CNTK static const size_t s_serializationVersion = 13; }; - class UDFUtils - { - public: - - static bool IsUDF(const FunctionPtr& f); - - static bool IsUDF(const Dictionary& dict); - - static Dictionary Serialize(const FunctionPtr& dictionary); - - static FunctionPtr Deserialize(const Dictionary& dictionary, - const std::unordered_map& uidToVariableMap, - const CNTK::DeviceDescriptor& device, - const UDFDeserializeCallback& callback); - - static const size_t s_serializationVersion = 0; - }; + std::vector GetInputUids(const Function& f); + Dictionary SerializeCommonFunctionAttributes(const Function& f, size_t version, const std::wstring& functionType); + std::vector GetInputVariables(const Dictionary& dict, const std::unordered_map& uidToVariableMap, size_t currentSerializationVersion); } diff --git a/Source/CNTKv2LibraryDll/Serialization.h b/Source/CNTKv2LibraryDll/Serialization.h index 68012e4af..d7efb8e7f 100644 --- a/Source/CNTKv2LibraryDll/Serialization.h +++ b/Source/CNTKv2LibraryDll/Serialization.h @@ -45,6 
+45,9 @@ namespace CNTK const std::wstring internalWorkerStateKey = L"internal_worker_state"; const std::wstring externalWorkerStateKey = L"external_worker_state"; const std::wstring userDefinedStateKey = L"user_defined_state"; + const std::wstring udfModuleNameKey = L"module"; + const std::wstring udfFactoryMethodNameKey = L"deserialize_method"; + const std::wstring nativeUDFKey = L"native"; template inline std::string GetVersionsString(size_t currentVersion, size_t dictVersion) diff --git a/Source/CNTKv2LibraryDll/UserDefinedFunction.cpp b/Source/CNTKv2LibraryDll/UserDefinedFunction.cpp new file mode 100644 index 000000000..9f0fc547b --- /dev/null +++ b/Source/CNTKv2LibraryDll/UserDefinedFunction.cpp @@ -0,0 +1,178 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE.md file in the project root for full license information. +// + +#include "UserDefinedFunction.h" +#include "UserFunctionFactory.h" +#include "Serialization.h" +#include "PrimitiveFunction.h" +#include "CompositeFunction.h" + + +namespace CNTK +{ + + static Internal::UDFDeserializeCallbackWrapperPtr s_SWIGCallbackWrapper; + + void Internal::RegisterUDFDeserializeCallbackWrapper(UDFDeserializeCallbackWrapperPtr callbackPtr) + { + s_SWIGCallbackWrapper = callbackPtr; + } + + static const std::wstring s_nativeUDFTypeValue = L"NativeUserDefinedFunction" ; + + static std::unordered_map> s_deserializedUDFsRegistry; + + Function::Function(const std::vector& inputs, const std::wstring& name) + : Function(inputs, {}, name) + {} + + Function::Function(const std::vector& inputs, const Dictionary& functionConfig, const std::wstring& name) + : Function(inputs, functionConfig, nullptr, name, Internal::GenerateUid(L"UserDefinedFunction")) + {} + + /*static*/ FunctionPtr Function::DeserializeNativeImpl(const std::vector& inputs, const std::wstring& name, const Dictionary& dict) + { + static const vector s_requiredDictionaryKeys = { userDefinedStateKey, 
udfModuleNameKey, udfFactoryMethodNameKey, opKey }; + ValidateDictionary(dict, s_requiredDictionaryKeys, s_nativeUDFTypeValue, s_serializationVersion); + + auto state = dict[userDefinedStateKey].Value(); + auto opName = dict[opKey].Value(); + auto moduleName = dict[udfModuleNameKey].Value(); + auto methodName = dict[udfFactoryMethodNameKey].Value(); + + FunctionPtr udf = nullptr; + + auto callback = Function::GetUDFDeserializeCallback(opName); + if (callback != nullptr) + { + udf = callback->operator()(inputs, name, state); + } + else + { + Microsoft::MSR::CNTK::Plugin plugin; + auto factoryMethod = (UserFunctionFactoryMethodType)(plugin.Load(moduleName, msra::strfun::utf8(methodName), /*isCNTKPlugin =*/ false)); + udf = FunctionPtr(factoryMethod(inputs.data(), inputs.size(), &state, name.c_str())); + } + + if (udf == nullptr) + { + RuntimeError("Unable to reconstruct the native UserFunction with op name '%S'", opName.c_str()); + } + + s_deserializedUDFsRegistry[opName] = { moduleName, methodName }; + + return udf; + } + + /*virtual*/ std::wstring Function::ModuleName() const + { + auto it = s_deserializedUDFsRegistry.find(OpName()); + if (it != s_deserializedUDFsRegistry.end()) + { + auto moduleAndMethodPair = it->second; + return moduleAndMethodPair.first; + } + + // this op name was never registered in the s_deserializedUDFsRegistry (which only happens during the deserialization), + // then use user factory as a fallback (this udf must have been registered, so that an instance could be created). 
+ return s_userFunctionFactory->GetModuleName(OpName()); + } + + /*virtual*/ std::wstring Function::DeserializeMethodName() const + { + auto it = s_deserializedUDFsRegistry.find(OpName()); + if (it != s_deserializedUDFsRegistry.end()) + { + auto moduleAndMethodPair = it->second; + return moduleAndMethodPair.second; + } + + // this op name was never registered in the s_deserializedUDFsRegistry (which only happens during the deserialization), + // then use user factory as a fallback (this udf must have been registered, so that an instance could be created). + return Function::s_userFunctionFactory->GetFactoryMethodName(OpName()); + } + + Dictionary Function::SerializeNativeImpl() const + { + Dictionary dict; + dict[userDefinedStateKey] = Serialize(); + dict[udfModuleNameKey] = ModuleName(); + dict[udfFactoryMethodNameKey] = DeserializeMethodName(); + dict[opKey] = OpName(); + dict[versionKey] = s_serializationVersion; + dict[typeKey] = s_nativeUDFTypeValue; + return dict; + } + + static const std::wstring s_userDefinedFunctionTypeValue = L"UserDefinedFunction"; + + /*static*/ bool UDFUtils::IsUDF(const FunctionPtr& f) + { + return (dynamic_cast(f.get()) == nullptr) && + (dynamic_cast(f.get()) == nullptr); + } + + /*static*/ bool UDFUtils::IsUDF(const Dictionary& dict) + { + return (dict.Contains(typeKey) && dict[typeKey].Value() == s_userDefinedFunctionTypeValue); + } + + /*static*/ bool UDFUtils::IsNativeUDF(const Dictionary& dict) + { + assert(IsUDF(dict)); + return (dict.Contains(nativeUDFKey) && dict[nativeUDFKey].Value() == true); + } + + /*static*/ Dictionary UDFUtils::Serialize(const FunctionPtr& f) + { + Dictionary dict = SerializeCommonFunctionAttributes(*f, s_serializationVersion, s_userDefinedFunctionTypeValue); + bool native = f->IsNative(); + dict[nativeUDFKey] = native; + dict[userDefinedStateKey] = (native) ? 
f->SerializeNativeImpl() : f->Serialize(); + return dict; + } + + /*static*/ FunctionPtr UDFUtils::Deserialize(const Dictionary& dict, + const unordered_map& uidToVariableMap, + const DeviceDescriptor& device) + { + static const vector s_requiredDictionaryKeys = { typeKey, uidKey, inputsKey, userDefinedStateKey }; + ValidateDictionary(dict, s_requiredDictionaryKeys, s_userDefinedFunctionTypeValue, s_serializationVersion); + + const auto& uid = dict[uidKey].Value(); + std::wstring name = L""; + if (dict.Contains(nameKey)) + name = dict[nameKey].Value(); + + auto inputs = GetInputVariables(dict, uidToVariableMap, s_serializationVersion); + + auto state = dict[userDefinedStateKey].Value(); + + FunctionPtr udf; + + if (IsNativeUDF(dict)) + { + udf = Function::DeserializeNativeImpl(inputs, name, state); + } + else if (s_SWIGCallbackWrapper != nullptr) + { + // If we're being called from SWIG, the actual deserializer should be registered by + // the target language CNTK implementation (i.e., cntk_py for Python) + udf = s_SWIGCallbackWrapper->operator()(inputs, name, state); + } + + if (udf == nullptr) + { + RuntimeError("Unable to reconstruct a user-defined function (name = %S, uid = %S). " + "Please make sure to specify a valid UDF deserializer.", name.c_str(), uid.c_str()); + } + + // Restore the original uid, which other functions in the graph depend on + // (their inputs refer to the uids of this UDF outputs, which are generated based on the uid of this UDF). + udf->m_uid = uid; + + return udf; + } +} \ No newline at end of file diff --git a/Source/CNTKv2LibraryDll/UserDefinedFunction.h b/Source/CNTKv2LibraryDll/UserDefinedFunction.h new file mode 100644 index 000000000..a7205d733 --- /dev/null +++ b/Source/CNTKv2LibraryDll/UserDefinedFunction.h @@ -0,0 +1,32 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE.md file in the project root for full license information. 
+// + +#pragma once + +#include "stdafx.h" +#include "CNTKLibrary.h" + +namespace CNTK +{ + class UDFUtils + { + public: + + static bool IsUDF(const FunctionPtr& f); + + static bool IsUDF(const Dictionary& dict); + + static Dictionary Serialize(const FunctionPtr& f); + + static FunctionPtr Deserialize(const Dictionary& dictionary, + const std::unordered_map& uidToVariableMap, + const CNTK::DeviceDescriptor& device); + + static const size_t s_serializationVersion = 0; + + private: + static bool IsNativeUDF(const Dictionary& dict); + }; +} diff --git a/Source/CNTKv2LibraryDll/UserFunctionFactory.h b/Source/CNTKv2LibraryDll/UserFunctionFactory.h index 56bccca8a..1302339d7 100644 --- a/Source/CNTKv2LibraryDll/UserFunctionFactory.h +++ b/Source/CNTKv2LibraryDll/UserFunctionFactory.h @@ -11,30 +11,60 @@ namespace CNTK { + typedef Function* (*UserFunctionFactoryMethodType)(const Variable* operands, size_t numOperands, const Dictionary* attributes, const wchar_t* name); + class UserFunctionFactory : public std::enable_shared_from_this { - typedef Function* (*UserFunctionFactoryMethodType)(const Variable* operands, size_t numOperands, const Dictionary* attributes, const wchar_t* name); - public: + + bool IsRegistered(const std::wstring& uniqueOpName) + { + std::unique_lock lock(m_mutex); + return IsRegisteredImpl(uniqueOpName); + } + void Register(const std::wstring& uniqueOpName, const std::wstring& moduleName, const std::wstring& factoryMethodName) { std::unique_lock lock(m_mutex); - if (m_userFunctionFactoryMethodMap.find(uniqueOpName) != m_userFunctionFactoryMethodMap.end()) + if (IsRegisteredImpl(uniqueOpName)) InvalidArgument("UserFunction with op name '%S' is already registered. 
All UserFunction op names must be unique.", uniqueOpName.c_str()); m_userFunctionFactoryMethodMap[uniqueOpName] = std::make_shared(moduleName, factoryMethodName); } + std::wstring GetModuleName(const std::wstring& uniqueOpName) + { + std::unique_lock lock(m_mutex); + if (!IsRegisteredImpl(uniqueOpName)) + InvalidArgument("UserFunction with op name '%S' has not been registered.", uniqueOpName.c_str()); + + return m_userFunctionFactoryMethodMap.at(uniqueOpName)->m_moduleName; + } + + std::wstring GetFactoryMethodName(const std::wstring& uniqueOpName) + { + std::unique_lock lock(m_mutex); + if (!IsRegisteredImpl(uniqueOpName)) + InvalidArgument("UserFunction with op name '%S' has not been registered.", uniqueOpName.c_str()); + + return m_userFunctionFactoryMethodMap.at(uniqueOpName)->m_factoryMethodName; + } + FunctionPtr CreateInstance(const std::wstring& opName, const std::vector& inputs, const Dictionary& functionConfig, const std::wstring& userFunctionInstanceName) { std::unique_lock lock(m_mutex); - if (m_userFunctionFactoryMethodMap.find(opName) == m_userFunctionFactoryMethodMap.end()) + if (!IsRegisteredImpl(opName)) InvalidArgument("UserFunction with op name '%S' has not been registered.", opName.c_str()); return std::shared_ptr(m_userFunctionFactoryMethodMap.at(opName)->m_factoryMethod(inputs.data(), inputs.size(), &functionConfig, userFunctionInstanceName.c_str())); } private: + bool IsRegisteredImpl(const std::wstring& uniqueOpName) + { + return (m_userFunctionFactoryMethodMap.find(uniqueOpName) != m_userFunctionFactoryMethodMap.end()); + } + struct UserFunctionFactoryMethodInfo : public std::enable_shared_from_this { UserFunctionFactoryMethodInfo(const std::wstring& moduleName, const std::wstring& factoryMethodName) @@ -53,4 +83,4 @@ namespace CNTK std::mutex m_mutex; std::unordered_map m_userFunctionFactoryMethodMap; }; -} +} \ No newline at end of file diff --git a/bindings/csharp/Swig/cntk_cs.i b/bindings/csharp/Swig/cntk_cs.i index 9d9941967..48c194d2c 
100755 --- a/bindings/csharp/Swig/cntk_cs.i +++ b/bindings/csharp/Swig/cntk_cs.i @@ -1927,25 +1927,34 @@ SWIG_STD_VECTOR_ENHANCED(CNTK::DeviceDescriptor) -%ignore CNTK::Function::Load(const std::wstring& filepath, const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice(), const Internal::UDFDeserializerPtr& deserializer); -%ignore CNTK::Function::Load(const char* buffer, size_t length, const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice(), const Internal::UDFDeserializerPtr& deserializer = nullptr); +%ignore CNTK::Function::Load(const std::wstring& filepath, const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice()); +%ignore CNTK::Function::Load(const char* buffer, size_t length, const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice()); // Ignore exposing istream to C# for now. Todo: find a good solution to map C# System.IO.Stream to std::istream. -%ignore CNTK::Function::Load(std::istream& inputStream, const DeviceDescriptor& computeDevice= DeviceDescriptor::UseDefaultDevice(), const Internal::UDFDeserializerPtr& deserializer = nullptr); +%ignore CNTK::Function::Load(std::istream& inputStream, const DeviceDescriptor& computeDevice= DeviceDescriptor::UseDefaultDevice()); %extend CNTK::Function { static FunctionPtr Load(const std::wstring& filepath, const CNTK::DeviceDescriptor& computeDevice = CNTK::DeviceDescriptor::UseDefaultDevice()) { - return CNTK::Function::Load(filepath, computeDevice, nullptr); + return CNTK::Function::Load(filepath, computeDevice); } static FunctionPtr Load(const char* modelBuffer, size_t length, const CNTK::DeviceDescriptor& computeDevice = CNTK::DeviceDescriptor::UseDefaultDevice()) { - return CNTK::Function::Load(modelBuffer, length, computeDevice, nullptr); + return CNTK::Function::Load(modelBuffer, length, computeDevice); } } + +%ignore CNTK::Function::RegisterUDFDeserializeCallback; +%ignore CNTK::Function::GetUDFDeserializeCallback; +%ignore_class 
CNTK::Internal::UDFDeserializeCallbackWrapper; +%ignore_function CNTK::Internal::RegisterUDFDeserializeCallbackWrapper; +%ignore_function CNTK::Internal::IsNativeUserFunctionRegistered; + + + %include "CNTKLibraryInternals.h" %include "CNTKLibrary.h" diff --git a/bindings/python/cntk/cntk_py.i b/bindings/python/cntk/cntk_py.i index 6bfd66277..e2f4b1b1c 100644 --- a/bindings/python/cntk/cntk_py.i +++ b/bindings/python/cntk/cntk_py.i @@ -21,7 +21,7 @@ %rename(_add_progress_writers) CNTK::Internal::AddProgressWriters; %rename(_backward) CNTK::Function::Backward; %rename(_infer_outputs) CNTK::Function::InferOutputs; -%rename(_serialize) CNTK::Function::Serialize; +%rename(_serialize_impl) CNTK::Function::Serialize; %rename(_deserialize) CNTK::Function::Deserialize; %rename(_update) CNTK::Learner::Update; %rename(sgd_learner) CNTK::SGDLearner; @@ -39,9 +39,9 @@ %rename(ctf_deserializer) CNTK::CTFDeserializer; %rename(htk_feature_deserializer) CNTK::HTKFeatureDeserializer; %rename(htk_mlf_deserializer) CNTK::HTKMLFDeserializer; -%rename(_infer_outputs) CNTK::Function::InferOutputs; %rename(_stream_infos) CNTK::SwigMinibatchSource::StreamInfos(PyObject*); %rename(_next_minibatch) CNTK::SwigMinibatchSource::_GetNextMinibatch; +%rename(_register_udf_deserialize_callback) CNTK::Internal::RegisterUDFDeserializeCallbackWrapper; %rename(_none) CNTK::DictionaryValue::Type::None; @@ -174,7 +174,8 @@ %ignore CNTK::Internal::TensorBoardFileWriter::TensorBoardFileWriter(const std::wstring& dir, const ::Microsoft::MSR::CNTK::ComputationNetworkPtr& modelToVisualize = nullptr); %ignore CNTK::Internal::Convolution; -%ignore CNTK::Function::Function(const std::vector& inputs, Dictionary&& functionConfig, const std::wstring& name = L"", const std::wstring& uid = Internal::GenerateUid(L"UserDefinedFunction")); +%ignore CNTK::Function::RegisterUDFDeserializeCallback; +%ignore CNTK::Function::GetUDFDeserializeCallback; %{ #define SWIG_FILE_WITH_INIT @@ -617,17 +618,32 @@ public: } } -// 
Callback support + %feature("director") CNTK::Function; -%feature("director") CNTK::Internal::UDFDeserializeCallbackWrapper; %feature("nodirector") CNTK::Function::OnPlaceholdersReplaced; %feature("nodirector") CNTK::Function::OpName; +// Callback support +%feature("director") CNTK::Internal::UDFDeserializeCallbackWrapper; + +%typemap(directorout) std::shared_ptr (void * swig_argp, int swig_res = 0) { + if ($input == Py_None) { + $result = $ltype(); + } else { + swig_res = SWIG_ConvertPtr($input, &swig_argp, $descriptor(std::shared_ptr *), %convertptr_flags); + if (!SWIG_IsOK(swig_res)) { + %dirout_fail(swig_res,"$type"); + } + $result = *(%reinterpret_cast(swig_argp, $<ype)); + } +} + +// Since there're three overloads of Function::Load, both "rename" and "compactdefaultargs" are needed +// to make pybuffer_binary work correctly with default arguments. +%rename(load_from_buffer) CNTK::Function::Load(const char*, size_t, const DeviceDescriptor&); +%feature("compactdefaultargs") CNTK::Function::Load(const char*, size_t, const DeviceDescriptor&); // This overload is not used in python at the moment. 
-%ignore CNTK::Function::Load(const std::wstring&, const DeviceDescriptor&, const UDFDeserializeCallback&); -%ignore CNTK::Function::Load(const char*, size_t, const DeviceDescriptor&, const UDFDeserializeCallback&); -%ignore CNTK::Function::Load(std::istream&, const DeviceDescriptor&, const UDFDeserializeCallback&); -%rename(load_from_buffer) CNTK::Function::Load(const char*, size_t, const CNTK::DeviceDescriptor&, const CNTK::Internal::UDFDeserializeCallbackWrapper&); +%ignore CNTK::Function::Load(std::istream&, const DeviceDescriptor&); %feature("director") CNTK::Learner; %feature("nodirector") CNTK::Learner::Parameters; @@ -812,7 +828,7 @@ public: } %enddef -// Implementing typemapping for UDFDeserializeCallback (it has a dictionary +// Implementing typemapping for UDFDeserializeCallbackWrapper (it has a dictionary // as one of its input parameters), which needs to be implemented in Python. %typemap(directorin, fragment="DictionaryValueToPy") const CNTK::Dictionary& @@ -1403,6 +1419,7 @@ std::unordered_map; %template(random_uniform_double) CNTK::NDArrayView::RandomUniform; %template(DictionaryValueFromDict) CNTK::DictionaryValue::DictionaryValue; +%template(DictionaryValueFromNDArrayView) CNTK::DictionaryValue::DictionaryValue; %template(training_parameter_per_sample_schedule) CNTK::TrainingParameterPerUnitSchedule::UnitType::Sample>; %template(training_parameter_per_minibatch_schedule) CNTK::TrainingParameterPerUnitSchedule::UnitType::Minibatch>; diff --git a/bindings/python/cntk/internal/__init__.py b/bindings/python/cntk/internal/__init__.py index 5aa32806a..aa60a52b5 100644 --- a/bindings/python/cntk/internal/__init__.py +++ b/bindings/python/cntk/internal/__init__.py @@ -23,31 +23,38 @@ def _value_as_sequence_or_array(val, var): map_if_possible(val) return val.asarray() +_serialization_version = 1 + +def _serialize(udf): + dictionary = {} + dictionary['class'] = udf.__class__.__name__ + dictionary['module'] = udf.__class__.__module__ + dictionary['op_name'] 
= udf.op_name + dictionary['state'] = udf.serialize() + dictionary['version'] = _serialization_version + return dictionary + + class _UDFDeserializeCallbackWrapper(cntk_py.UDFDeserializeCallbackWrapper): - ''' - Provides an implementation of the UDFDeserializer interface used - to inflate user defined functions in a model dictionary. - ''' def __init__(self, factory_callback_map=None): super(_UDFDeserializeCallbackWrapper, self).__init__() self.factory_callback_map = factory_callback_map - self.__disown__() def __call__(self, inputs, name, dictionary): - import pdb; pdb.set_trace() cls = dictionary['class'] module = dictionary['module'] state = dictionary['state'] op_name = dictionary['op_name'] + deserialize_method = 'deserialize' if (self.factory_callback_map and op_name in self.factory_callback_map): factory = self.factory_callback_map[op_name] else: exec("from {} import {}".format(module, cls)) - eval_str = "{0}.deserialize if hasattr({0}, 'deserialize') else None" - factory = eval(eval_str.format(cls)) + eval_str = "{0}.{1} if hasattr({0}, '{1}') else None" + factory = eval(eval_str.format(cls, deserialize_method)) - if (factory): + if factory: return factory(list(inputs), name, state) raise ValueError("Cannot deserialize user function '{}.{}'. " diff --git a/bindings/python/cntk/internal/sanitize.py b/bindings/python/cntk/internal/sanitize.py index 24de6e809..5e31eb83b 100644 --- a/bindings/python/cntk/internal/sanitize.py +++ b/bindings/python/cntk/internal/sanitize.py @@ -18,6 +18,15 @@ def is_string(s): ''' return isinstance(s, ("".__class__, u"".__class__)) +def is_byte_buffer(s): + ''' + Tests whether ``s`` is a byte buffer (not a string) in a way that + works on Python 2 and 3. + ''' + return (isinstance(s, bytearray) or + (isinstance(s, type(b'')) and not isinstance(b'', str))) + + def _as_tuple(x): ''' Convert an argument to a tuple. 
diff --git a/bindings/python/cntk/internal/utils.py b/bindings/python/cntk/internal/utils.py index 50e35218c..fa8fd88b7 100644 --- a/bindings/python/cntk/internal/utils.py +++ b/bindings/python/cntk/internal/utils.py @@ -6,6 +6,8 @@ from .. import cntk_py import numpy as np +from cntk import NDArrayView +from ..cntk_py import DictionaryValueFromDict, DictionaryValue, Dictionary, DictionaryValueFromNDArrayView _VARIABLE_OR_FUNCTION = (cntk_py.Variable, cntk_py.Function) @@ -206,7 +208,6 @@ def _py_dict_to_cntk_dict(py_dict): cntk_py.Dictionary: A :class:`~cntk.cntk_py.Dictionary` that has been converted from the input `dict` ''' - from ..cntk_py import DictionaryValueFromDict, DictionaryValue, Dictionary def _to_cntk_dict_value(py_value): if isinstance(py_value, dict): return DictionaryValueFromDict(_py_dict_to_cntk_dict(py_value)) @@ -214,7 +215,14 @@ def _py_dict_to_cntk_dict(py_dict): if isinstance(py_value, list): py_list = list(map(_to_cntk_dict_value, py_value)) return DictionaryValue(py_list) - + + if isinstance(py_value, np.ndarray): + py_value = NDArrayView.from_dense(py_value) + return DictionaryValueFromNDArrayView(py_value) + + if py_value is None: + return DictionaryValue() + return DictionaryValue(py_value) res = Dictionary() diff --git a/bindings/python/cntk/ops/functions.py b/bindings/python/cntk/ops/functions.py index 6305d717e..731554c9b 100644 --- a/bindings/python/cntk/ops/functions.py +++ b/bindings/python/cntk/ops/functions.py @@ -13,7 +13,8 @@ from cntk.internal import map_if_possible, typemap, sanitize_var_map,\ _value_as_sequence_or_array from cntk.internal.utils import get_python_function_arguments, \ map_function_arguments, _py_dict_to_cntk_dict -from cntk.internal import _UDFDeserializeCallbackWrapper +from cntk.internal import _UDFDeserializeCallbackWrapper, _serialize +from cntk.internal.sanitize import is_byte_buffer from ..variables import Record, Variable @@ -90,6 +91,10 @@ class Function(cntk_py.Function): in 
:func:`~cntk.ops.as_block()`, using the supplied ``op_name`` and ``name`` parameters, which are otherwise ignored. ''' + _udf_callback_map = {} + _deserializer = _UDFDeserializeCallbackWrapper(_udf_callback_map) + cntk_py._register_udf_deserialize_callback(_deserializer) + # We override the constructors to implement an overload that constructs # a CNTK Functions from a Python function (@Function). def __new__(cls, *args, **kwargs): @@ -1135,25 +1140,43 @@ class Function(cntk_py.Function): 'restore(...) instead', DeprecationWarning) return self.restore(filename) + @staticmethod + def register_udf_deserialize_callback(op_name, callback): + ''' + Register a callback function to be invoked when deserializing a user- + defined function with the corresponding op name. + + When loading a model, CNTK will try to automatically reconstruct any + (non-native) user-defined functions by invoking a static + :func:`~cntk.ops.functions.UserFunction.deserialize` method of the + corresponding UserFunction sub-class. This method allows to override + default UDF deserialization behavior by specifying a user- defined + function op name and the corresponding callback that should be invoked + instead of the ``deserialize`` method. + + Args: + op_name (str): unique op name of the user-defined function. + callback (function): a function taking three arguments (a list of + inputs to the UserFunction, a string name, and a state dictionary + generated by the corresponding :func:`~cntk.ops.functions.UserFunction.serialize` + method) and returns an instance of the user-defined function. 
+ ''' + if op_name in Function._udf_callback_map: + raise ValueError("A callback for the UserFunction with op name {}" + " has already been registered.".format(op_name)); + Function._udf_callback_map[op_name] = callback + @staticmethod @typemap - def load(model, device=None, udf_factory_callback_map=None): + def load(model, device=None): ''' Load the ``model``, that has been saved using :func:`~cntk.ops.functions.Function.save`. Args: - model (str or bytes): either a filepath of a model file or a byte buffer + model (str, bytes or bytearray): either a file path of a model file or a byte buffer containing the binary representation of a model. device (:class:`~cntk.device.DeviceDescriptor`, defaults to the current globally default device): specifies the device to allocate the model on. - udf_factory_callback_map (dict, default is `None`): if the model contains any user-defined - functions, CNTK will try to automatically reconstruct them by invoking a static - ``deserialize`` method of the corresponding Function sub-class. This method takes three - arguments (a list of inputs to the function, a string name, and a state dictionary - generated by the corresponding :func:`~cntk.ops.functions.UserFunction.serialize` method) and - returns an instance of the user-defined function. This optional argument allows to override - default UDF deserialization behavior by providing a map of user-function op names and - corresponding lambdas that should be invoked instead of the ``deserialize`` method. 
Returns: root node @@ -1161,10 +1184,7 @@ class Function(cntk_py.Function): if not device: device = DeviceDescriptor.use_default_device() - deserializer = _UDFDeserializeCallbackWrapper(udf_factory_callback_map) - - is_buffer = isinstance(model, type(b'')) and not isinstance(b'', str) - is_buffer = is_buffer or isinstance(model, bytearray) + is_buffer = is_byte_buffer(model) is_file = False if not is_buffer: @@ -1174,10 +1194,10 @@ class Function(cntk_py.Function): pass if is_buffer: - return cntk_py.Function.load_from_buffer(model, device, deserializer) + return cntk_py.Function.load_from_buffer(model, device) if is_file: - return cntk_py.Function.load(model, device, deserializer) + return cntk_py.Function.load(model, device) raise ValueError('Cannot load a model that is neither a file nor a byte buffer.') @@ -1225,11 +1245,11 @@ def native_user_function(op_name, operands, attributes=None, user_function_insta return cntk_py.Function_native_user_function(op_name, operands, attributes, user_function_instance_name) @typemap -def load_model(model, device=None, udf_factory_callback_map=None): +def load_model(model, device=None): ''' Alias for :func:`~cntk.ops.functions.Function.load`. ''' - return Function.load(model, device, udf_factory_callback_map) + return Function.load(model, device) @typemap def save_model(model, filename): # legacy name @@ -1251,8 +1271,10 @@ class UserFunction(Function): `False` passes the data as CNTK Value objects. 
name (str): name of this function ''' + def __init__(self, inputs, as_numpy=True, name=''): super(UserFunction, self).__init__(inputs, name) + self.set_native(False) self.as_numpy = as_numpy # Since the state will frequently not be used, we cache the None-state @@ -1402,14 +1424,36 @@ class UserFunction(Function): ''' raise NotImplementedError('clone has to be overwritten') - def _serialize(self): - dictionary = {} - dictionary['class'] = self.__class__.__name__ - dictionary['module'] = self.__class__.__module__ - dictionary['op_name'] = self.op_name - dictionary['state'] = self.serialize() + def _serialize_impl(self): + dictionary = _serialize(self) return _py_dict_to_cntk_dict(dictionary) + @staticmethod + def deserialize(inputs, name, state): + ''' + A stub deserialize method for illustration purposes. User-defined functions + need to provide their own implementation in order for CNTK to be able to + reconstruct them when loading a model. + + Args: + inputs (list): a list of inputs to the function + name (str): name of this function + state (dict): a state dictionary generated by the corresponding + :func:`~cntk.ops.functions.UserFunction.serialize` method. + + Returns: + An instance of the user-defined function. + ''' + raise NotImplementedError('a stub method for illustration purposes.') + + @property + def op_name(self): + ''' + Unique operation name of this user-defined function. + This property defaults to '.', but can be overridden. + ''' + return self.__class__._op_name() + def serialize(self): ''' Generates a dictionary that captures the state of this user-defined function. @@ -1418,3 +1462,7 @@ class UserFunction(Function): to be preserved in the model dictionary. ''' return {} + + @classmethod + def _op_name(cls): + return cls.__module__ + '.' 
+ cls.__name__ diff --git a/bindings/python/cntk/ops/tests/userfunction_test.py b/bindings/python/cntk/ops/tests/userfunction_test.py index 768e86e5c..1761ee9df 100644 --- a/bindings/python/cntk/ops/tests/userfunction_test.py +++ b/bindings/python/cntk/ops/tests/userfunction_test.py @@ -49,10 +49,8 @@ class MyPlus(UserFunction): input_gradients[input] = root_gradients def serialize(self): - internal_state = {} - internal_state['forward_calls'] = self.forward_calls - internal_state['backward_calls'] = self.backward_calls - return internal_state + return {'forward_calls' : self.forward_calls, + 'backward_calls' : self.backward_calls} @staticmethod def deserialize(inputs, name, state): @@ -61,6 +59,24 @@ class MyPlus(UserFunction): f.backward_calls = state['backward_calls'] return f + +class MyPlusPlus(MyPlus): + def __init__(self, inputs, name, state={}): + super(MyPlusPlus, self).__init__(inputs[0], inputs[1], name=name+name) + + def forward(self, *args, **kwargs): + r1 = super(MyPlusPlus, self).forward(*args, **kwargs) + r2 = super(MyPlusPlus, self).forward(*args, **kwargs) + return None, r1[1] + r2[1] + + def serialize(self): + return None + + @staticmethod + def deserialize(*args): + return MyPlusPlus(*args) + + def test_ext_eval_1(): dim = 4 p = parameter(shape=(dim,), init=10, name='p') @@ -315,12 +331,14 @@ def test_ext_lambdafunc(tmpdir): filepath = str(tmpdir / 'test_ext_lambdafunc.dat') z0.save(filepath) - z = Function.load(filepath, udf_factory_callback_map = - { - 'conditional_exec_lambda' : - lambda x, *unused: - LambdaFunc(x, when=lambda arg: np.sum(arg)>1, execute=cb.inc) - }) + + Function.register_udf_deserialize_callback( + 'conditional_exec_lambda', + lambda x, *unused: + LambdaFunc(x, when=lambda arg: np.sum(arg)>1, execute=cb.inc)) + + z = Function.load(filepath) + momentum_time_constant = momentum_as_time_constant_schedule(1100) lr_per_sample = learning_rate_schedule(0.007, UnitType.sample) trainer = Trainer(z, (z+0, z+0), \ @@ -484,18 +502,136 
@@ def test_udf_output_needs_no_gradient(): assert np.allclose(result, [[[6., 8.], [10., 12.]]]) -def test_native_user_function(): - ops.register_native_user_function('NativeUserTimesFunction', 'Cntk.ExtensibilityExamples-' + C.__version__.rstrip('+'), 'CreateUserTimesFunction') +def test_native_user_function(tmpdir): + + if not cntk_py.is_native_user_function_registered('NativeUserTimesOp'): + ops.register_native_user_function('NativeUserTimesOp', 'Cntk.ExtensibilityExamples-' + C.__version__.rstrip('+'), 'CreateUserTimesFunction') + dev = cpu() x = input((2)) w = parameter((2, 2), init=np.asarray([[0.5, 2], [-0.5, 1.5]], dtype=np.float32), device=dev) - attributes = {'param_rank' : 2, 'padding' : True} - op = ops.native_user_function('NativeUserTimesFunction', [w, x], attributes, 'native_user_times_function') + attributes = { + 'param_rank' : 2, + 'padding' : True, + 'none': None, + 'nested lists' : [[1,2,3], [4,5,6]], + 'string' : 'string', + 'some data' : np.arange(1, 10, dtype=np.float32).reshape((3,3)) + } + def verify_attributes(udf): + for k,v in attributes.items(): + if not isinstance(v, np.ndarray): + assert udf.attributes[k] == v + else: + assert (udf.attributes[k] == v).all() + + op = ops.native_user_function('NativeUserTimesOp', [w, x], attributes, 'native_user_times_function') + + verify_attributes(op.owner) + + filepath = str(tmpdir / 'test_native_user_function.dat') + op.save(filepath) + + op_reloaded = Function.load(filepath, device=dev) x_data = NDArrayView.from_dense(np.asarray([[0.1, 0.2], [-0.1, 0.3]], dtype=np.float32), device=dev) - result = op.eval({x : x_data}, device=dev) + result = op_reloaded.eval({op_reloaded.arguments[0] : x_data}, device=dev) + assert np.allclose(result, [[-0.05, 0.5], [-0.2, 0.25]]) - native_times_primitive = op.find_by_name('native_user_times_function') - assert native_times_primitive.attributes == attributes - \ No newline at end of file + native_times_primitive = 
op_reloaded.find_by_name('native_user_times_function') + + verify_attributes(native_times_primitive) + + +def build_test_function(): + dev = cpu() + w_value = np.asarray([[0.5, 2], [-0.5, 1.5]]).astype(np.float32) + c1_value = 2.718 + c2_value = -3.141 + + if not cntk_py.is_native_user_function_registered('NativeUserTimesOp'): + ops.register_native_user_function('NativeUserTimesOp', 'Cntk.ExtensibilityExamples-' + C.__version__.rstrip('+'), 'CreateUserTimesFunction') + + x = input((2)) + + w = parameter((2, 2), init=w_value, device=dev) + + mp = MyPlus(x, constant(c1_value)) + op = user_function(MyPlus(x, constant(c1_value))) + op = ops.native_user_function('NativeUserTimesOp', [w, op], user_function_instance_name='my_times') + + return dev, w_value, c1_value, c2_value, user_function(MyPlus(op, constant(c2_value))) + +def test_both_flavors_of_user_functions(tmpdir): + dev, w_value, c1_value, c2_value, op = build_test_function() + + filepath = str(tmpdir / 'test_native_user_function.dat') + op.save(filepath) + op_reloaded = Function.load(filepath, device=dev) + + np.random.seed(1) + + for i in range(5): + x_value = np.random.random((2,2)).astype(np.float32) + x_data = NDArrayView.from_dense(x_value, device=dev) + result = op_reloaded.eval({op_reloaded.arguments[0] : x_data}, device=dev) + expected = np.matmul((x_value + c1_value), w_value) + c2_value + assert np.allclose(result, expected) + +def test_udf_checkpointing(tmpdir): + dev, w_value, c1_value, c2_value, op = build_test_function() + + label = constant(np.asarray([[1, 2], [3,4]]).astype(np.float32)) + + loss = cross_entropy_with_softmax(op, label) + eval_error = classification_error(op, label) + + lr_schedule = learning_rate_schedule(0.5, UnitType.minibatch) + learner = sgd(op.parameters, lr_schedule) + trainer = Trainer(op, (loss, eval_error), [learner]) + + trainer.train_minibatch({op.arguments[0] : np.random.random((2,2)).astype(np.float32)}, device=dev) + + filepath = str(tmpdir /'test_checkpointing.out') 
+ + trainer.save_checkpoint(filepath, external_state={'test':'test'}) + + d = cntk_py.Dictionary.load(filepath) + assert len(d.keys()) != 0 + +def test_override_deserialize(tmpdir): + dev, w_value, c1_value, c2_value, op = build_test_function() + + filepath = str(tmpdir / 'test_override_deserialize.dat') + op.save(filepath) + + Function.register_udf_deserialize_callback( + MyPlus._op_name(), lambda *x : MyPlusPlus(*x)) + + op_reloaded = Function.load(filepath, device=dev) + + np.random.seed(1) + + for i in range(5): + x_value = np.random.random((2,2)).astype(np.float32) + x_data = NDArrayView.from_dense(x_value, device=dev) + result = op_reloaded.eval({op_reloaded.arguments[0] : x_data}, device=dev) + expected = 2 *(np.matmul(2 * (x_value + c1_value), w_value) + c2_value) + assert np.allclose(result, expected) + +def test_override_serialize(tmpdir): + dev=cpu() + a,b = 1.2322341, -0.29084; + op = MyPlusPlus([constant(a), constant(b)], '++'); + op = MyPlusPlus([op, op], '+++') + op = MyPlusPlus([op, op], '++++') + op = user_function(op) + result1 = op.eval({}, device=dev) + + filepath = str(tmpdir / 'test_udf_with_renamed_deserialize.dat') + op.save(filepath) + + op_reloaded = Function.load(filepath, device=dev) + + assert result1 == op_reloaded.eval({}, device=dev) \ No newline at end of file diff --git a/bindings/python/doc/extend.rst b/bindings/python/doc/extend.rst index 61536fd03..ca2ad332b 100644 --- a/bindings/python/doc/extend.rst +++ b/bindings/python/doc/extend.rst @@ -14,6 +14,7 @@ Implementing a custom operator in pure Python is simple matter of - implementing ``forward()`` and ``backward()``, whose signatures dependent on the number of inputs and outputs - specifying the outputs' shape, data type and dynamic axes in ``infer_outputs()`` + - providing a static ``deserialize()`` method to inflate previously saved function In the simplest case, just only one input and output, ``forward()`` takes an argument and returns a tuple of a state and the result. 
The state can be used to @@ -46,6 +47,10 @@ tuple, strings, etc.):: return [output_variable(self.inputs[0].shape, self.inputs[0].dtype, self.inputs[0].dynamic_axes)] + @staticmethod + def deserialize(inputs, name, state): + return MySigmoid(inputs[0], name) + This can now be used as a normal operator like:: from cntk import user_function @@ -61,13 +66,17 @@ In case, the operator is initialized with multiple inputs, ``forward()`` 's class MyPlus(UserFunction): def __init__(self, arg1, arg2, name='f1'): super(MyPlus, self).__init__([arg1, arg2], name=name) + self.forward_calls = 0 + self.backward_calls = 0 def forward(self, arguments, device=None, outputs_to_retain=None): # No state needs to be passed to backward() so we just # pass None + self.forward_calls += 1 return None, arguments[0] + arguments[1] def backward(self, state, root_gradients): + self.backward_calls += 1 return root_gradients def infer_outputs(self): @@ -76,6 +85,17 @@ In case, the operator is initialized with multiple inputs, ``forward()`` 's # result would actually look like (considering broadcasting, etc.). return [output_variable(self.inputs[0].shape, self.inputs[0].dtype, self.inputs[0].dynamic_axes)] + def serialize(self): + return {'forward_calls' : self.forward_calls, + 'backward_calls' : self.backward_calls} + + @staticmethod + def deserialize(inputs, name, state): + f = MyPlus(inputs[0], inputs[1], name) + f.forward_calls = state['forward_calls'] + f.backward_calls = state['backward_calls'] + return f + If the UserFunction has more than one input, ``backward()`` is invoked with an additional ``variables`` argument, which contains a dictionary of Variable to the gradient data, whose values have to be set with the proper @@ -101,6 +121,13 @@ have to be set with the proper data. In addition, ``root_gradient`` in ``backward()`` is a dictionary of Variable to the root_gradient. +``deserialize()`` is invoked by CNTK to reconstruct a previously saved function.
It should +have the same signature as the :func:`~cntk.ops.functions.UserFunction.deserialize` method. +In case of a stateless function, it simply needs to invoke the constructor and return an +instance of the user function. However, if the function is stateful and overrides the +:func:`~cntk.ops.functions.UserFunction.serialize` method, ``deserialize()`` also needs to +properly restore the function state. + Using user functions for debugging ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~