* Add a base class for User-Defined Functions
  * Add support for native udf serialization
  * Change API to accept std::function callbacks
This commit is contained in:
Alexey Reznichenko 2017-04-28 16:53:26 +02:00
Родитель 6c9d42f762
Коммит 5bc3661988
24 изменённых файлов: 754 добавлений и 241 удалений

Просмотреть файл

@ -16,41 +16,71 @@ void UserTimesFunctionExample()
auto x = InputVariable(NDShape({ inDim }), DataType::Float, { Axis::DefaultBatchAxis() });
auto userDefinedTimes = UserTimesFunction::Create(W, x, L"UserDefinedTimes");
size_t batchSize = 3;
std::vector<float> inputData(inDim * batchSize);
for (size_t i = 0; i < inputData.size(); ++i)
inputData[i] = (float)rand() / RAND_MAX;
auto inputDataValue = Value::CreateBatch(x.Shape(), inputData, device);
auto compareWithBuiltInTimes = [device, outDim, inDim](FunctionPtr times) {
size_t batchSize = 3;
std::vector<float> inputData(inDim * batchSize);
for (size_t i = 0; i < inputData.size(); ++i)
inputData[i] = (float)rand() / RAND_MAX;
std::vector<float> rootGradientData(outDim * batchSize, 1);
auto rootGradientValue = Value::CreateBatch(userDefinedTimes->Output().Shape(), rootGradientData, device);
auto input = times->Arguments()[0];
auto inputDataValue = Value::CreateBatch(input.Shape(), inputData, device);
std::unordered_map<Variable, ValuePtr> outputValues = { { userDefinedTimes->Output(), nullptr } };
auto backPropState = userDefinedTimes->Forward({ { x, inputDataValue } }, outputValues, device, { userDefinedTimes->Output() });
std::vector<float> rootGradientData(outDim * batchSize, 1);
auto rootGradientValue = Value::CreateBatch(times->Output().Shape(), rootGradientData, device);
std::unordered_map<Variable, ValuePtr> inputGradientValues = { { W, nullptr } };
userDefinedTimes->Backward(backPropState, { { userDefinedTimes->Output(), rootGradientValue } }, inputGradientValues);
auto userDefinedTimesOutputValue = outputValues[userDefinedTimes->Output()];
auto userDefinedTimesInputGradientValue = inputGradientValues[W];
std::unordered_map<Variable, ValuePtr> outputValues = { { times->Output(), nullptr } };
auto backPropState = times->Forward({ { input, inputDataValue } }, outputValues, device, { times->Output() });
// Compare against the CNTK built-in implementation
auto builtInTimes = Times(W, x, L"BuiltInTimes");
outputValues = { { builtInTimes->Output(), nullptr } };
backPropState = builtInTimes->Forward({ { x, inputDataValue } }, outputValues, device, { builtInTimes->Output() });
inputGradientValues = { { W, nullptr } };
builtInTimes->Backward(backPropState, { { builtInTimes->Output(), rootGradientValue } }, inputGradientValues);
auto builtInTimesOutputValue = outputValues[builtInTimes->Output()];
auto builtInTimesInputGradientValue = inputGradientValues[W];
const double relativeTolerance = 0.001f;
const double absoluteTolerance = 0.000001f;
auto parameter = times->Parameters()[0];
if (!Internal::AreEqual(*userDefinedTimesOutputValue, *builtInTimesOutputValue, relativeTolerance, absoluteTolerance))
std::runtime_error("UserTimesOp's Forward result does not match built-in result");
std::unordered_map<Variable, ValuePtr> inputGradientValues = { { parameter, nullptr } };
times->Backward(backPropState, { { times->Output(), rootGradientValue } }, inputGradientValues);
auto userDefinedTimesOutputValue = outputValues[times->Output()];
auto userDefinedTimesInputGradientValue = inputGradientValues[parameter];
if (!Internal::AreEqual(*userDefinedTimesInputGradientValue, *builtInTimesInputGradientValue, relativeTolerance, absoluteTolerance))
std::runtime_error("UserTimesOp's Forward result does not match built-in result");
// Compare against the CNTK built-in implementation
auto builtInTimes = Times(parameter, input, L"BuiltInTimes");
outputValues = { { builtInTimes->Output(), nullptr } };
backPropState = builtInTimes->Forward({ { input, inputDataValue } }, outputValues, device, { builtInTimes->Output() });
inputGradientValues = { { parameter, nullptr } };
builtInTimes->Backward(backPropState, { { builtInTimes->Output(), rootGradientValue } }, inputGradientValues);
auto builtInTimesOutputValue = outputValues[builtInTimes->Output()];
auto builtInTimesInputGradientValue = inputGradientValues[parameter];
const double relativeTolerance = 0.001f;
const double absoluteTolerance = 0.000001f;
if (!Internal::AreEqual(*userDefinedTimesOutputValue, *builtInTimesOutputValue, relativeTolerance, absoluteTolerance))
std::runtime_error("UserTimesOp's Forward result does not match built-in result");
if (!Internal::AreEqual(*userDefinedTimesInputGradientValue, *builtInTimesInputGradientValue, relativeTolerance, absoluteTolerance))
std::runtime_error("UserTimesOp's Forward result does not match built-in result");
};
compareWithBuiltInTimes(userDefinedTimes);
auto version = std::string(CNTK_COMPONENT_VERSION);
std::wstring wversion(version.begin(), version.end());
Function::RegisterNativeUserFunction(L"NativeUserTimesOp", L"Cntk.ExtensibilityExamples-" + wversion, L"CreateUserTimesFunction");
userDefinedTimes->Save(L"UserTimesFunctionExample.model");
auto userDefinedTimes_reloaded_1 = Function::Load(L"UserTimesFunctionExample.model", device);
compareWithBuiltInTimes(userDefinedTimes_reloaded_1);
Function::RegisterUDFDeserializeCallback(L"NativeUserTimesOp", [](const std::vector<Variable>& inputs,
const std::wstring& name,
const Dictionary& state) {
return UserTimesFunction::Create(inputs[0], inputs[1], state, name);
});
auto userDefinedTimes_reloaded_2 = Function::Load(L"UserTimesFunctionExample.model", device);
compareWithBuiltInTimes(userDefinedTimes_reloaded_2);
}
#pragma warning(pop)

Просмотреть файл

@ -21,7 +21,7 @@ public:
}
UserTimesFunction(const Variable& leftOperand, const Variable& rightOperand, const Dictionary& attributes, const std::wstring& name)
: Function({ leftOperand, rightOperand }, Dictionary(attributes), name)
: Function({ leftOperand, rightOperand }, attributes, name)
{}
private:
@ -116,11 +116,10 @@ private:
const std::wstring& OpName() const override
{
static const std::wstring opName = L"UserTimesOp";
static const std::wstring opName = L"NativeUserTimesOp";
return opName;
}
Dictionary Serialize() const override { NOT_IMPLEMENTED; }
size_t CurrentVersion() const override { NOT_IMPLEMENTED; }
void InferOutputs(std::vector<Variable>& outputs) override

Просмотреть файл

@ -469,6 +469,7 @@ CNTKLIBRARY_COMMON_SRC =\
$(SOURCEDIR)/CNTKv2LibraryDll/Function.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/PrimitiveFunction.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/CompositeFunction.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/UserDefinedFunction.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/NDArrayView.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/NDMask.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/Trainer.cpp \

Просмотреть файл

@ -2885,6 +2885,8 @@ namespace CNTK
const std::wstring& /*name*/,
const Dictionary& /*dictionary*/)> UDFDeserializeCallback;
typedef std::shared_ptr<UDFDeserializeCallback> UDFDeserializeCallbackPtr;
static auto NoOp = [] (const std::vector<Variable>&, const std::wstring&, const Dictionary&)
{
return nullptr;
@ -2907,6 +2909,7 @@ namespace CNTK
friend class Trainer;
friend Variable GetCorrespondingOutputVariableFromClone(const Variable&, const FunctionPtr&, const FunctionPtr&);
friend bool Internal::IsNativeUserFunctionRegistered(const std::wstring& uniqueOpName);
public:
@ -2975,6 +2978,20 @@ namespace CNTK
///
CNTK_API virtual void InferOutputs(std::vector<Variable>& outputs) = 0;
///
/// Returns the name of the module (dll/so) containing this function. For native functions registered through
/// a call to 'RegisterNativeUserFunction', unless overridden, this method return the value of the 'moduleName'
/// argument.
///
CNTK_API virtual std::wstring ModuleName() const;
///
/// Returns the name of the method which should be invoked to deserialize this function. For native functions
/// registered through a call to 'RegisterNativeUserFunction', unless overridden, this method return the value
/// of the 'factoryMethodName' argument. If overridden, it must have the same signature as the factory method.
///
CNTK_API virtual std::wstring DeserializeMethodName() const;
public:
// Optional overrides
@ -2997,7 +3014,7 @@ namespace CNTK
///
/// Generates a dictionary that captures the state of the Function graph underlying this Function.
///
CNTK_API virtual Dictionary Serialize() const override { return Dictionary(); }
CNTK_API virtual Dictionary Serialize() const override { return Attributes(); }
///
/// Creates a clone of this Function instance, using the specified 'inputs' that are inputs of the clone to be constructed.
@ -3052,8 +3069,7 @@ namespace CNTK
/// if deserializer was omitted). If there are no user defined functions in the model, deserializer is ignored.
///
CNTK_API static FunctionPtr Deserialize(const Dictionary& dictionary,
const ::CNTK::DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice(),
const UDFDeserializeCallback& deserializer = NoOp);
const ::CNTK::DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice());
public:
///
@ -3250,44 +3266,19 @@ namespace CNTK
/// Load a Function from a model file
///
CNTK_API static FunctionPtr Load(const std::wstring& filepath,
const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice(),
const UDFDeserializeCallback& deserializer = NoOp);
const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());
///
/// Load a Function from a memory buffer
///
CNTK_API static FunctionPtr Load(const char* buffer, size_t length,
const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice(),
const UDFDeserializeCallback& deserializer = NoOp);
const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());
///
/// Load a Function from an istream. The legacy V1 model is not supported.
///
CNTK_API static FunctionPtr Load(std::istream& inputStream,
const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice(),
const UDFDeserializeCallback& deserializer = NoOp);
#ifdef SWIG // for Python interop (adds callback wrapper)
static CNTK::FunctionPtr Load(const std::wstring& filepath,
const CNTK::DeviceDescriptor& computeDevice,
const CNTK::Internal::UDFDeserializeCallbackWrapper& wrapper)
{
using namespace std::placeholders;
UDFDeserializeCallback callback = std::bind(&CNTK::Internal::UDFDeserializeCallbackWrapper::operator(),
&wrapper, _1, _2, _3);
return CNTK::Function::Load(filepath, computeDevice, callback);
}
static CNTK::FunctionPtr Load(const char* buffer, size_t length,
const CNTK::DeviceDescriptor& computeDevice,
const CNTK::Internal::UDFDeserializeCallbackWrapper& wrapper)
{
using namespace std::placeholders;
UDFDeserializeCallback callback = std::bind(&CNTK::Internal::UDFDeserializeCallbackWrapper::operator(),
&wrapper, _1, _2, _3);
return CNTK::Function::Load(buffer, length, computeDevice, callback);
}
#endif
const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());
///
/// Prints the entire graph underlying this Function to stderr
@ -3317,6 +3308,18 @@ namespace CNTK
///
CNTK_API static FunctionPtr NativeUserFunction(const std::wstring& opName, const std::vector<Variable>& operands, const Dictionary& functionConfig, const std::wstring& userFunctionInstanceName = L"");
///
/// Register a callback function to be invoked when deserializing a user-defined Function with the corresponding op name.
/// When loading a model, CNTK will try to automatically reconstruct user-defined Functions (for native functions, CNTK will
/// invoke the same factory method, the Function op name was registered with). This method allows to override
/// default user-defined Function deserialization behavior by specifying an op name and the corresponding callback that should be invoked
/// to inflate the Function object.
///
CNTK_API static void RegisterUDFDeserializeCallback(const std::wstring& uniqueOpName, const UDFDeserializeCallback& deserializer);
static UDFDeserializeCallbackPtr GetUDFDeserializeCallback(const std::wstring& uniqueOpName);
protected:
static bool IsArgument(const Variable& var)
{
@ -3324,9 +3327,10 @@ namespace CNTK
}
///
/// Protected constructor for derived 'Function' types to specify the actual input and output variables for the (primitive) Function instance.
/// Protected constructors for derived user-defined 'Function' types to specify the actual input and output variables for the (primitive) Function instance.
///
CNTK_API Function(const std::vector<Variable>& inputs, Dictionary&& functionConfig, const std::wstring& name = L"", const std::wstring& uid = Internal::GenerateUid(L"UserDefinedFunction"));
CNTK_API Function(const std::vector<Variable>& inputs, const Dictionary& functionConfig, const std::wstring& name = L"");
CNTK_API Function(const std::vector<Variable>& inputs, const std::wstring& name = L"");
template <typename FunctionType>
static void PreorderTraverseFunctions(const FunctionPtr& rootFunction, const FunctionType& functor, bool traverseInsideBlockFunction = false)
@ -3420,14 +3424,11 @@ namespace CNTK
// Disallow copy and move construction and assignment
Function(const Function&) = delete; Function(Function&&) = delete; Function& operator=(const Function&) = delete; Function& operator=(Function&&) = delete;
public:
CNTK_API Function(const std::vector<Variable>& inputs, const std::wstring& name = L"");
private:
static UserFunctionFactoryPtr s_userFunctionFactory;
private:
CNTK_API Function(const std::vector<Variable>& inputs, Dictionary&& functionConfig, const FunctionPtr& rootFunction, const std::wstring& name, const std::wstring& uid);
Function(const std::vector<Variable>& inputs, const Dictionary& functionConfig, const FunctionPtr& rootFunction, const std::wstring& name, const std::wstring& uid);
std::vector<Variable> m_inputs;
std::once_flag m_outputsInitFlag;
@ -3437,6 +3438,22 @@ namespace CNTK
std::wstring m_name;
std::wstring m_uid;
Dictionary m_attributes;
#ifdef SWIG
public:
void SetNative(bool native) { m_native = native; }
#endif
private:
bool IsNative() const { return m_native; }
bool m_native = true;
Dictionary SerializeNativeImpl() const;
static FunctionPtr DeserializeNativeImpl(const std::vector<Variable>& inputs, const std::wstring& name, const Dictionary& dict);
static const size_t s_serializationVersion = 1;
};
///

Просмотреть файл

@ -376,8 +376,6 @@ namespace CNTK
std::wstring m_fileName;
};
#ifdef SWIG
// SWIG callback wrapper for the UDF deserialization.
class UDFDeserializeCallbackWrapper
{
@ -385,8 +383,13 @@ namespace CNTK
virtual FunctionPtr operator()(const std::vector<Variable>&, const std::wstring&, const Dictionary&) const = 0;
virtual ~UDFDeserializeCallbackWrapper() = default;
};
#endif
typedef std::shared_ptr<UDFDeserializeCallbackWrapper> UDFDeserializeCallbackWrapperPtr;
CNTK_API void RegisterUDFDeserializeCallbackWrapper(UDFDeserializeCallbackWrapperPtr callbackPtr);
CNTK_API bool IsNativeUserFunctionRegistered(const std::wstring& uniqueOpName);
}
// Forward-declare test fixtures, so that they can be used as friends.

Просмотреть файл

@ -137,6 +137,7 @@
<ClInclude Include="PrimitiveOpType.h" />
<ClInclude Include="Serialization.h" />
<ClInclude Include="tensorboard\TensorBoardUtils.h" />
<ClInclude Include="UserDefinedFunction.h" />
<ClInclude Include="UserFunctionFactory.h" />
<ClInclude Include="Utils.h" />
<ClInclude Include="stdafx.h" />
@ -175,6 +176,7 @@
<ClCompile Include="tensorboard\TensorBoardUtils.cpp" />
<ClCompile Include="Trainer.cpp" />
<ClCompile Include="TrainingSession.cpp" />
<ClCompile Include="UserDefinedFunction.cpp" />
<ClCompile Include="Utils.cpp" />
<ClCompile Include="Value.cpp" />
<ClCompile Include="Variable.cpp" />

Просмотреть файл

@ -36,6 +36,7 @@
</ClCompile>
<ClCompile Include="ProgressWriter.cpp" />
<ClCompile Include="Evaluator.cpp" />
<ClCompile Include="UserDefinedFunction.cpp" />
</ItemGroup>
<ItemGroup>
<ClInclude Include="stdafx.h" />
@ -67,6 +68,7 @@
<ClInclude Include="BlockFunction.h" />
<ClInclude Include="Variable.h" />
<ClInclude Include="UserFunctionFactory.h" />
<ClInclude Include="UserDefinedFunction.h" />
</ItemGroup>
<ItemGroup>
<Filter Include="API">

Просмотреть файл

@ -23,6 +23,7 @@
#include "BlockFunction.h"
#include "SpecialPurposeNodes.h"
#include "SequenceReshapeNodes.h"
#include "UserDefinedFunction.h"
using namespace Microsoft::MSR::CNTK;
@ -58,7 +59,7 @@ namespace CNTK
for (auto& function : m_allPrimitiveFunctions)
{
auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(function.get());
if (!primitiveFunction->IsStateful())
if (!primitiveFunction || !primitiveFunction->IsStateful())
continue;
// TODO: same for BatchNorm
@ -85,7 +86,7 @@ namespace CNTK
for (auto& function : m_allPrimitiveFunctions)
{
auto primitiveFunction = dynamic_cast<const PrimitiveFunction*>(function.get());
if (!primitiveFunction->IsStateful())
if (!primitiveFunction || !primitiveFunction->IsStateful())
continue;
// TODO: same for BatchNorm
@ -224,7 +225,7 @@ namespace CNTK
return composite;
}
/*static*/ FunctionPtr CompositeFunction::Deserialize(const Dictionary& dict, const CNTK::DeviceDescriptor& device, const UDFDeserializeCallback& callback)
/*static*/ FunctionPtr CompositeFunction::Deserialize(const Dictionary& dict, const CNTK::DeviceDescriptor& device)
{
static const vector<std::wstring> s_requiredDictionaryKeys = { inputsKey, functionsKey };
@ -254,7 +255,7 @@ namespace CNTK
{
auto functionDict = dictionaryValue.Value<Dictionary>();
FunctionPtr root = UDFUtils::IsUDF(functionDict) ?
UDFUtils::Deserialize(functionDict, uidToInputMap, device, callback) :
UDFUtils::Deserialize(functionDict, uidToInputMap, device) :
PrimitiveFunction::Deserialize(functionDict, uidToInputMap, allPrimitiveFunctions, allPlaceholderReplacements, device);
allPrimitiveFunctions.insert(root);
@ -303,7 +304,7 @@ namespace CNTK
for (auto& function : functions)
{
auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(function.get());
if (!primitiveFunction->IsStateful())
if (!primitiveFunction || !primitiveFunction->IsStateful())
continue;
if (stateDictionary.Contains(primitiveFunction->Uid()))
@ -340,7 +341,7 @@ namespace CNTK
vector<wstring> uids;
PreorderTraverseFunctions(function.RootFunction(), [&uids](const FunctionPtr& funcPtr) {
auto primitiveFunction = dynamic_cast<const PrimitiveFunction*>(funcPtr.get());
if (primitiveFunction->IsStateful())
if (primitiveFunction && primitiveFunction->IsStateful())
{
uids.push_back(funcPtr->Uid());
}
@ -384,7 +385,7 @@ namespace CNTK
for (const auto& function : m_allPrimitiveFunctions)
{
auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(function.get());
if (!primitiveFunction->IsStateful())
if (!primitiveFunction || !primitiveFunction->IsStateful())
continue;
auto functionState = state[primitiveFunction->Uid()].Value<Dictionary>();

Просмотреть файл

@ -119,7 +119,7 @@ namespace CNTK
const std::unordered_map<Variable, Variable>& allPlaceholderReplacements,
const CNTK::DeviceDescriptor& device);
static FunctionPtr Deserialize(const Dictionary& dictionary, const CNTK::DeviceDescriptor& device, const UDFDeserializeCallback& callback);
static FunctionPtr Deserialize(const Dictionary& dictionary, const CNTK::DeviceDescriptor& device);
virtual const std::wstring& OpName() const override
{

Просмотреть файл

@ -27,6 +27,30 @@ namespace CNTK
return AsComposite(s_userFunctionFactory->CreateInstance(opName, operands, functionConfig, userFunctionInstanceName), userFunctionInstanceName);
}
bool Internal::IsNativeUserFunctionRegistered(const std::wstring& uniqueOpName)
{
return Function::s_userFunctionFactory->IsRegistered(uniqueOpName);
}
static std::unordered_map<std::wstring, UDFDeserializeCallbackPtr> udfCallbackMap;
static std::mutex udfCallbackMapMutex;
/*static*/ void Function::RegisterUDFDeserializeCallback(const std::wstring& uniqueOpName, const UDFDeserializeCallback& deserializer)
{
std::unique_lock<std::mutex> lock(udfCallbackMapMutex);
auto result = udfCallbackMap.insert({ uniqueOpName, make_shared<UDFDeserializeCallback>(deserializer) });
if (!result.second)
InvalidArgument("A callback for the UserFunction with op name '%S' has already been registered.", uniqueOpName.c_str());
}
/*static*/ UDFDeserializeCallbackPtr Function::GetUDFDeserializeCallback(const std::wstring& uniqueOpName)
{
std::unique_lock<std::mutex> lock(udfCallbackMapMutex);
if (udfCallbackMap.find(uniqueOpName) == udfCallbackMap.end())
return nullptr;
return udfCallbackMap.at(uniqueOpName);
}
std::vector<Variable>& Function::InitOutputs()
{
std::call_once(m_outputsInitFlag, [this]() {
@ -85,10 +109,6 @@ namespace CNTK
return std::shared_ptr<std::vector<Variable>>(new std::vector<Variable>(std::move(inputs)), [](std::vector<Variable>* ptr) { delete ptr; });
}
Function::Function(const std::vector<Variable>& inputs, Dictionary&& functionConfig, const std::wstring& name, const std::wstring& uid)
: Function(inputs, std::move(functionConfig), nullptr, name, uid)
{}
std::shared_ptr<std::vector<Variable>> Function::OutputsImpl() const
{
std::vector<Variable> outputs;
@ -99,11 +119,7 @@ namespace CNTK
return std::shared_ptr<std::vector<Variable>>(new std::vector<Variable>(std::move(outputs)), [](std::vector<Variable>* ptr) { delete ptr; });
}
Function::Function(const std::vector<Variable>& inputs, const std::wstring& name)
: Function(inputs, Dictionary(), name)
{}
Function::Function(const std::vector<Variable>& inputs, Dictionary&& functionConfig, const FunctionPtr& rootFunction, const std::wstring& name, const std::wstring& uid)
Function::Function(const std::vector<Variable>& inputs, const Dictionary& functionConfig, const FunctionPtr& rootFunction, const std::wstring& name, const std::wstring& uid)
: m_rootFunction(rootFunction), m_name(name), m_uid(uid), m_attributes(std::move(functionConfig))
{
for (auto inputVar : inputs)
@ -438,14 +454,14 @@ namespace CNTK
stream->flush();
}
/*static*/ FunctionPtr Function::Load(const std::wstring& filepath, const DeviceDescriptor& computeDevice, const UDFDeserializeCallback& callback)
/*static*/ FunctionPtr Function::Load(const std::wstring& filepath, const DeviceDescriptor& computeDevice)
{
auto stream = GetFstream(filepath, true);
if (!Internal::IsLegacyModel(*stream))
{
Dictionary model;
*stream >> model;
return Function::Deserialize(model, computeDevice, callback);
return Function::Deserialize(model, computeDevice);
}
else
{
@ -453,7 +469,7 @@ namespace CNTK
}
}
/*static*/ FunctionPtr Function::Load(const char *buffer, size_t length, const DeviceDescriptor& computeDevice, const UDFDeserializeCallback& callback)
/*static*/ FunctionPtr Function::Load(const char *buffer, size_t length, const DeviceDescriptor& computeDevice)
{
if ((buffer == nullptr) || (length <= 0))
InvalidArgument("The model buffer should not be null and its length should be greater than 0");
@ -474,15 +490,15 @@ namespace CNTK
modelStreamBuffer buf(buffer, length);
std::istream modelStream(&buf);
return Load(modelStream, computeDevice, callback);
return Load(modelStream, computeDevice);
}
}
/*static*/ FunctionPtr Function::Load(std::istream& inputStream, const DeviceDescriptor& computeDevice, const UDFDeserializeCallback& callback)
/*static*/ FunctionPtr Function::Load(std::istream& inputStream, const DeviceDescriptor& computeDevice)
{
Dictionary model;
inputStream >> model;
return Function::Deserialize(model, computeDevice, callback);
return Function::Deserialize(model, computeDevice);
}
void Function::Restore(const std::wstring& filepath)
@ -904,9 +920,9 @@ namespace CNTK
compositeFunction->CopyState(*restoredCompositeFunction);
}
/*static*/ FunctionPtr Function::Deserialize(const Dictionary& modelDictionary, const CNTK::DeviceDescriptor& device, const UDFDeserializeCallback& callback)
/*static*/ FunctionPtr Function::Deserialize(const Dictionary& modelDictionary, const CNTK::DeviceDescriptor& device)
{
return CompositeFunction::Deserialize(modelDictionary, device, callback);
return CompositeFunction::Deserialize(modelDictionary, device);
}
void Function::PrintGraph() const

Просмотреть файл

@ -21,6 +21,7 @@
#include "ConvolveGeometry.h"
#include "ConvolutionalNodes.h"
#include "Variable.h"
#include "UserFunctionFactory.h"
using namespace Microsoft::MSR::CNTK;
@ -963,7 +964,7 @@ namespace CNTK
}
}
static vector<DictionaryValue> GetInputUids(const Function& f)
vector<DictionaryValue> GetInputUids(const Function& f)
{
auto inputs = f.Inputs();
vector<DictionaryValue> inputUids;
@ -975,7 +976,7 @@ namespace CNTK
return inputUids;
}
static Dictionary SerializeCommonAttributes(const Function& f, size_t version, const wstring& functionType)
Dictionary SerializeCommonFunctionAttributes(const Function& f, size_t version, const wstring& functionType)
{
Dictionary dict;
dict[versionKey] = version;
@ -991,7 +992,7 @@ namespace CNTK
/*virtual*/ Dictionary PrimitiveFunction::Serialize() const
{
Dictionary dict = SerializeCommonAttributes(*this, CurrentVersion(), s_primitiveFunctionTypeValue);
Dictionary dict = SerializeCommonFunctionAttributes(*this, CurrentVersion(), s_primitiveFunctionTypeValue);
dict[opKey] = static_cast<size_t>(m_op);
dict[attributesKey] = Attributes();
@ -1018,7 +1019,7 @@ namespace CNTK
return dict;
}
static std::vector<Variable> GetInputVariables(const Dictionary& dict, const unordered_map<wstring, Variable>& uidToVariableMap, size_t currentSerializationVersion)
std::vector<Variable> GetInputVariables(const Dictionary& dict, const std::unordered_map<std::wstring, Variable>& uidToVariableMap, size_t currentSerializationVersion)
{
const auto& inputUids = dict[inputsKey].Value<vector<DictionaryValue>>();
@ -1246,54 +1247,4 @@ namespace CNTK
return dummyOutputVariable.Shape();
}
static const std::wstring s_userDefinedFunctionTypeValue = L"UserDefinedFunction";
/*static*/ bool UDFUtils::IsUDF(const FunctionPtr& f)
{
return (dynamic_cast<const PrimitiveFunction*>(f.get()) == nullptr);
}
/*static*/ bool UDFUtils::IsUDF(const Dictionary& dict)
{
return (dict.Contains(typeKey) && dict[typeKey].Value<std::wstring>() == s_userDefinedFunctionTypeValue);
}
/*static*/ Dictionary UDFUtils::Serialize(const FunctionPtr& udf)
{
Dictionary dict = SerializeCommonAttributes(*udf, s_serializationVersion, s_userDefinedFunctionTypeValue);
dict[userDefinedStateKey] = udf->Serialize();
return dict;
}
/*static*/ FunctionPtr UDFUtils::Deserialize(const Dictionary& dict,
const unordered_map<std::wstring, Variable>& uidToVariableMap,
const DeviceDescriptor& device,
const UDFDeserializeCallback& callback)
{
static const vector<std::wstring> s_requiredDictionaryKeys = { typeKey, uidKey, inputsKey, userDefinedStateKey };
ValidateDictionary<PrimitiveFunction>(dict, s_requiredDictionaryKeys, s_userDefinedFunctionTypeValue, s_serializationVersion);
const auto& uid = dict[uidKey].Value<std::wstring>();
std::wstring name = L"";
if (dict.Contains(nameKey))
name = dict[nameKey].Value<std::wstring>();
auto inputs = GetInputVariables(dict, uidToVariableMap, s_serializationVersion);
auto state = dict[userDefinedStateKey].Value<Dictionary>();
auto udf = callback(inputs, name, state);
if (udf == nullptr)
{
RuntimeError("Unable to reconstruct a user-defined function. Please make sure to specify a valid UDF deserializer.");
}
// Restore the original uid, which other functions in the graph depend on
// (their inputs refer to the uids of this UDF outputs, which are generated base on the uid of this UDF).
udf->m_uid = uid;
return udf;
}
}

Просмотреть файл

@ -263,7 +263,7 @@ namespace CNTK
protected:
PrimitiveFunction(PrimitiveOpType op, const std::vector<Variable>& inputs, Dictionary&& functionConfig, const std::wstring& functionName, const std::wstring& uid)
: Function(inputs, std::move(functionConfig), functionName, uid), m_op(op)
: Function(inputs, std::move(functionConfig), nullptr, functionName, uid), m_op(op)
{}
public:
@ -759,21 +759,7 @@ namespace CNTK
static const size_t s_serializationVersion = 13;
};
class UDFUtils
{
public:
static bool IsUDF(const FunctionPtr& f);
static bool IsUDF(const Dictionary& dict);
static Dictionary Serialize(const FunctionPtr& dictionary);
static FunctionPtr Deserialize(const Dictionary& dictionary,
const std::unordered_map<std::wstring, Variable>& uidToVariableMap,
const CNTK::DeviceDescriptor& device,
const UDFDeserializeCallback& callback);
static const size_t s_serializationVersion = 0;
};
std::vector<DictionaryValue> GetInputUids(const Function& f);
Dictionary SerializeCommonFunctionAttributes(const Function& f, size_t version, const std::wstring& functionType);
std::vector<Variable> GetInputVariables(const Dictionary& dict, const std::unordered_map<std::wstring, Variable>& uidToVariableMap, size_t currentSerializationVersion);
}

Просмотреть файл

@ -45,6 +45,9 @@ namespace CNTK
const std::wstring internalWorkerStateKey = L"internal_worker_state";
const std::wstring externalWorkerStateKey = L"external_worker_state";
const std::wstring userDefinedStateKey = L"user_defined_state";
const std::wstring udfModuleNameKey = L"module";
const std::wstring udfFactoryMethodNameKey = L"deserialize_method";
const std::wstring nativeUDFKey = L"native";
template <typename T>
inline std::string GetVersionsString(size_t currentVersion, size_t dictVersion)

Просмотреть файл

@ -0,0 +1,178 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#include "UserDefinedFunction.h"
#include "UserFunctionFactory.h"
#include "Serialization.h"
#include "PrimitiveFunction.h"
#include "CompositeFunction.h"
namespace CNTK
{
static Internal::UDFDeserializeCallbackWrapperPtr s_SWIGCallbackWrapper;
void Internal::RegisterUDFDeserializeCallbackWrapper(UDFDeserializeCallbackWrapperPtr callbackPtr)
{
s_SWIGCallbackWrapper = callbackPtr;
}
static const std::wstring s_nativeUDFTypeValue = L"NativeUserDefinedFunction" ;
static std::unordered_map<std::wstring, std::pair<std::wstring, std::wstring>> s_deserializedUDFsRegistry;
Function::Function(const std::vector<Variable>& inputs, const std::wstring& name)
: Function(inputs, {}, name)
{}
Function::Function(const std::vector<Variable>& inputs, const Dictionary& functionConfig, const std::wstring& name)
: Function(inputs, functionConfig, nullptr, name, Internal::GenerateUid(L"UserDefinedFunction"))
{}
/*static*/ FunctionPtr Function::DeserializeNativeImpl(const std::vector<Variable>& inputs, const std::wstring& name, const Dictionary& dict)
{
static const vector<std::wstring> s_requiredDictionaryKeys = { userDefinedStateKey, udfModuleNameKey, udfFactoryMethodNameKey, opKey };
ValidateDictionary<PrimitiveFunction>(dict, s_requiredDictionaryKeys, s_nativeUDFTypeValue, s_serializationVersion);
auto state = dict[userDefinedStateKey].Value<Dictionary>();
auto opName = dict[opKey].Value<wstring>();
auto moduleName = dict[udfModuleNameKey].Value<wstring>();
auto methodName = dict[udfFactoryMethodNameKey].Value<wstring>();
FunctionPtr udf = nullptr;
auto callback = Function::GetUDFDeserializeCallback(opName);
if (callback != nullptr)
{
udf = callback->operator()(inputs, name, state);
}
else
{
Microsoft::MSR::CNTK::Plugin plugin;
auto factoryMethod = (UserFunctionFactoryMethodType)(plugin.Load(moduleName, msra::strfun::utf8(methodName), /*isCNTKPlugin =*/ false));
udf = FunctionPtr(factoryMethod(inputs.data(), inputs.size(), &state, name.c_str()));
}
if (udf == nullptr)
{
RuntimeError("Unable to reconstruct the native UserFunction with op name '%S'", opName.c_str());
}
s_deserializedUDFsRegistry[opName] = { moduleName, methodName };
return udf;
}
/*virtual*/ std::wstring Function::ModuleName() const
{
auto it = s_deserializedUDFsRegistry.find(OpName());
if (it != s_deserializedUDFsRegistry.end())
{
auto moduleAndMethodPair = it->second;
return moduleAndMethodPair.first;
}
// this op name was never registered in the s_deserializedUDFsRegistry (which only happens during the deserialization),
// then use user factory as a fallback (this udf must have been registed, so that an instance could be created).
return s_userFunctionFactory->GetModuleName(OpName());
}
/*virtual*/ std::wstring Function::DeserializeMethodName() const
{
auto it = s_deserializedUDFsRegistry.find(OpName());
if (it != s_deserializedUDFsRegistry.end())
{
auto moduleAndMethodPair = it->second;
return moduleAndMethodPair.second;
}
// this op name was never registered in the s_deserializedUDFsRegistry (which only happens during the deserialization),
// then use user factory as a fallback (this udf must have been registed, so that an instance could be created).
return Function::s_userFunctionFactory->GetFactoryMethodName(OpName());
}
Dictionary Function::SerializeNativeImpl() const
{
Dictionary dict;
dict[userDefinedStateKey] = Serialize();
dict[udfModuleNameKey] = ModuleName();
dict[udfFactoryMethodNameKey] = DeserializeMethodName();
dict[opKey] = OpName();
dict[versionKey] = s_serializationVersion;
dict[typeKey] = s_nativeUDFTypeValue;
return dict;
}
static const std::wstring s_userDefinedFunctionTypeValue = L"UserDefinedFunction";
/*static*/ bool UDFUtils::IsUDF(const FunctionPtr& f)
{
return (dynamic_cast<const PrimitiveFunction*>(f.get()) == nullptr) &&
(dynamic_cast<const CompositeFunction*>(f.get()) == nullptr);
}
/*static*/ bool UDFUtils::IsUDF(const Dictionary& dict)
{
return (dict.Contains(typeKey) && dict[typeKey].Value<std::wstring>() == s_userDefinedFunctionTypeValue);
}
/*static*/ bool UDFUtils::IsNativeUDF(const Dictionary& dict)
{
assert(IsUDF(dict));
return (dict.Contains(nativeUDFKey) && dict[nativeUDFKey].Value<bool>() == true);
}
/*static*/ Dictionary UDFUtils::Serialize(const FunctionPtr& f)
{
Dictionary dict = SerializeCommonFunctionAttributes(*f, s_serializationVersion, s_userDefinedFunctionTypeValue);
bool native = f->IsNative();
dict[nativeUDFKey] = native;
dict[userDefinedStateKey] = (native) ? f->SerializeNativeImpl() : f->Serialize();
return dict;
}
/*static*/ FunctionPtr UDFUtils::Deserialize(const Dictionary& dict,
const unordered_map<std::wstring, Variable>& uidToVariableMap,
const DeviceDescriptor& device)
{
static const vector<std::wstring> s_requiredDictionaryKeys = { typeKey, uidKey, inputsKey, userDefinedStateKey };
ValidateDictionary<PrimitiveFunction>(dict, s_requiredDictionaryKeys, s_userDefinedFunctionTypeValue, s_serializationVersion);
const auto& uid = dict[uidKey].Value<std::wstring>();
std::wstring name = L"";
if (dict.Contains(nameKey))
name = dict[nameKey].Value<std::wstring>();
auto inputs = GetInputVariables(dict, uidToVariableMap, s_serializationVersion);
auto state = dict[userDefinedStateKey].Value<Dictionary>();
FunctionPtr udf;
if (IsNativeUDF(dict))
{
udf = Function::DeserializeNativeImpl(inputs, name, state);
}
else if (s_SWIGCallbackWrapper != nullptr)
{
// If we're being called from SWIG, the actual deserializer should be registered by
// the target language CNTK implementation (i.e., cnkt_py for Python)
udf = s_SWIGCallbackWrapper->operator()(inputs, name, state);
}
if (udf == nullptr)
{
RuntimeError("Unable to reconstruct a user-defined function (name = %S, uid = %S). "
"Please make sure to specify a valid UDF deserializer.", name.c_str(), uid.c_str());
}
// Restore the original uid, which other functions in the graph depend on
// (their inputs refer to the uids of this UDF outputs, which are generated base on the uid of this UDF).
udf->m_uid = uid;
return udf;
}
}

Просмотреть файл

@ -0,0 +1,32 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#pragma once
#include "stdafx.h"
#include "CNTKLibrary.h"
namespace CNTK
{
class UDFUtils
{
public:
static bool IsUDF(const FunctionPtr& f);
static bool IsUDF(const Dictionary& dict);
static Dictionary Serialize(const FunctionPtr& f);
static FunctionPtr Deserialize(const Dictionary& dictionary,
const std::unordered_map<std::wstring, Variable>& uidToVariableMap,
const CNTK::DeviceDescriptor& device);
static const size_t s_serializationVersion = 0;
private:
static bool IsNativeUDF(const Dictionary& dict);
};
}

Просмотреть файл

@ -11,30 +11,60 @@
namespace CNTK
{
typedef Function* (*UserFunctionFactoryMethodType)(const Variable* operands, size_t numOperands, const Dictionary* attributes, const wchar_t* name);
class UserFunctionFactory : public std::enable_shared_from_this<UserFunctionFactory>
{
typedef Function* (*UserFunctionFactoryMethodType)(const Variable* operands, size_t numOperands, const Dictionary* attributes, const wchar_t* name);
public:
bool IsRegistered(const std::wstring& uniqueOpName)
{
std::unique_lock<std::mutex> lock(m_mutex);
return IsRegisteredImpl(uniqueOpName);
}
void Register(const std::wstring& uniqueOpName, const std::wstring& moduleName, const std::wstring& factoryMethodName)
{
std::unique_lock<std::mutex> lock(m_mutex);
if (m_userFunctionFactoryMethodMap.find(uniqueOpName) != m_userFunctionFactoryMethodMap.end())
if (IsRegisteredImpl(uniqueOpName))
InvalidArgument("UserFunction with op name '%S' is already registered. All UserFunction op names must be unique.", uniqueOpName.c_str());
m_userFunctionFactoryMethodMap[uniqueOpName] = std::make_shared<UserFunctionFactoryMethodInfo>(moduleName, factoryMethodName);
}
std::wstring GetModuleName(const std::wstring& uniqueOpName)
{
std::unique_lock<std::mutex> lock(m_mutex);
if (!IsRegisteredImpl(uniqueOpName))
InvalidArgument("UserFunction with op name '%S' has not been registered.", uniqueOpName.c_str());
return m_userFunctionFactoryMethodMap.at(uniqueOpName)->m_moduleName;
}
std::wstring GetFactoryMethodName(const std::wstring& uniqueOpName)
{
std::unique_lock<std::mutex> lock(m_mutex);
if (!IsRegisteredImpl(uniqueOpName))
InvalidArgument("UserFunction with op name '%S' has not been registered.", uniqueOpName.c_str());
return m_userFunctionFactoryMethodMap.at(uniqueOpName)->m_factoryMethodName;
}
FunctionPtr CreateInstance(const std::wstring& opName, const std::vector<Variable>& inputs, const Dictionary& functionConfig, const std::wstring& userFunctionInstanceName)
{
std::unique_lock<std::mutex> lock(m_mutex);
if (m_userFunctionFactoryMethodMap.find(opName) == m_userFunctionFactoryMethodMap.end())
if (!IsRegisteredImpl(opName))
InvalidArgument("UserFunction with op name '%S' has not been registered.", opName.c_str());
return std::shared_ptr<Function>(m_userFunctionFactoryMethodMap.at(opName)->m_factoryMethod(inputs.data(), inputs.size(), &functionConfig, userFunctionInstanceName.c_str()));
}
private:
bool IsRegisteredImpl(const std::wstring& uniqueOpName)
{
return (m_userFunctionFactoryMethodMap.find(uniqueOpName) != m_userFunctionFactoryMethodMap.end());
}
struct UserFunctionFactoryMethodInfo : public std::enable_shared_from_this<UserFunctionFactoryMethodInfo>
{
UserFunctionFactoryMethodInfo(const std::wstring& moduleName, const std::wstring& factoryMethodName)
@ -53,4 +83,4 @@ namespace CNTK
std::mutex m_mutex;
std::unordered_map<std::wstring, UserFunctionFactoryMethodInfoPtr> m_userFunctionFactoryMethodMap;
};
}
}

Просмотреть файл

@ -1927,25 +1927,34 @@ SWIG_STD_VECTOR_ENHANCED(CNTK::DeviceDescriptor)
%ignore CNTK::Function::Load(const std::wstring& filepath, const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice(), const Internal::UDFDeserializerPtr& deserializer);
%ignore CNTK::Function::Load(const char* buffer, size_t length, const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice(), const Internal::UDFDeserializerPtr& deserializer = nullptr);
%ignore CNTK::Function::Load(const std::wstring& filepath, const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());
%ignore CNTK::Function::Load(const char* buffer, size_t length, const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());
// Ignore exposing istream to C# for now. Todo: find a good solution to map C# System.IO.Stream to std::istream.
%ignore CNTK::Function::Load(std::istream& inputStream, const DeviceDescriptor& computeDevice= DeviceDescriptor::UseDefaultDevice(), const Internal::UDFDeserializerPtr& deserializer = nullptr);
%ignore CNTK::Function::Load(std::istream& inputStream, const DeviceDescriptor& computeDevice= DeviceDescriptor::UseDefaultDevice());
%extend CNTK::Function {
static FunctionPtr Load(const std::wstring& filepath,
const CNTK::DeviceDescriptor& computeDevice = CNTK::DeviceDescriptor::UseDefaultDevice())
{
return CNTK::Function::Load(filepath, computeDevice, nullptr);
return CNTK::Function::Load(filepath, computeDevice);
}
static FunctionPtr Load(const char* modelBuffer, size_t length,
const CNTK::DeviceDescriptor& computeDevice = CNTK::DeviceDescriptor::UseDefaultDevice())
{
return CNTK::Function::Load(modelBuffer, length, computeDevice, nullptr);
return CNTK::Function::Load(modelBuffer, length, computeDevice);
}
}
%ignore CNTK::Function::RegisterUDFDeserializeCallback;
%ignore CNTK::Function::GetUDFDeserializeCallback;
%ignore_class CNTK::Internal::UDFDeserializeCallbackWrapper;
%ignore_function CNTK::Internal::RegisterUDFDeserializeCallbackWrapper;
%ignore_function CNTK::Internal::IsNativeUserFunctionRegistered;
%include "CNTKLibraryInternals.h"
%include "CNTKLibrary.h"

Просмотреть файл

@ -21,7 +21,7 @@
%rename(_add_progress_writers) CNTK::Internal::AddProgressWriters;
%rename(_backward) CNTK::Function::Backward;
%rename(_infer_outputs) CNTK::Function::InferOutputs;
%rename(_serialize) CNTK::Function::Serialize;
%rename(_serialize_impl) CNTK::Function::Serialize;
%rename(_deserialize) CNTK::Function::Deserialize;
%rename(_update) CNTK::Learner::Update;
%rename(sgd_learner) CNTK::SGDLearner;
@ -39,9 +39,9 @@
%rename(ctf_deserializer) CNTK::CTFDeserializer;
%rename(htk_feature_deserializer) CNTK::HTKFeatureDeserializer;
%rename(htk_mlf_deserializer) CNTK::HTKMLFDeserializer;
%rename(_infer_outputs) CNTK::Function::InferOutputs;
%rename(_stream_infos) CNTK::SwigMinibatchSource::StreamInfos(PyObject*);
%rename(_next_minibatch) CNTK::SwigMinibatchSource::_GetNextMinibatch;
%rename(_register_udf_deserialize_callback) CNTK::Internal::RegisterUDFDeserializeCallbackWrapper;
%rename(_none) CNTK::DictionaryValue::Type::None;
@ -174,7 +174,8 @@
%ignore CNTK::Internal::TensorBoardFileWriter::TensorBoardFileWriter(const std::wstring& dir, const ::Microsoft::MSR::CNTK::ComputationNetworkPtr& modelToVisualize = nullptr);
%ignore CNTK::Internal::Convolution;
%ignore CNTK::Function::Function(const std::vector<Variable>& inputs, Dictionary&& functionConfig, const std::wstring& name = L"", const std::wstring& uid = Internal::GenerateUid(L"UserDefinedFunction"));
%ignore CNTK::Function::RegisterUDFDeserializeCallback;
%ignore CNTK::Function::GetUDFDeserializeCallback;
%{
#define SWIG_FILE_WITH_INIT
@ -617,17 +618,32 @@ public:
}
}
// Callback support
%feature("director") CNTK::Function;
%feature("director") CNTK::Internal::UDFDeserializeCallbackWrapper;
%feature("nodirector") CNTK::Function::OnPlaceholdersReplaced;
%feature("nodirector") CNTK::Function::OpName;
// Callback support
%feature("director") CNTK::Internal::UDFDeserializeCallbackWrapper;
%typemap(directorout) std::shared_ptr<CNTK::Function> (void * swig_argp, int swig_res = 0) {
if ($input == Py_None) {
$result = $ltype();
} else {
swig_res = SWIG_ConvertPtr($input, &swig_argp, $descriptor(std::shared_ptr<CNTK::Function> *), %convertptr_flags);
if (!SWIG_IsOK(swig_res)) {
%dirout_fail(swig_res,"$type");
}
$result = *(%reinterpret_cast(swig_argp, $&ltype));
}
}
// Since there're three overloads of Function::Load, both "rename" and "compactdefaultargs" are needed
// to make pybuffer_binary work correctly with default arguments.
%rename(load_from_buffer) CNTK::Function::Load(const char*, size_t, const DeviceDescriptor&);
%feature("compactdefaultargs") CNTK::Function::Load(const char*, size_t, const DeviceDescriptor&);
// This overload is not used in python at the moment.
%ignore CNTK::Function::Load(const std::wstring&, const DeviceDescriptor&, const UDFDeserializeCallback&);
%ignore CNTK::Function::Load(const char*, size_t, const DeviceDescriptor&, const UDFDeserializeCallback&);
%ignore CNTK::Function::Load(std::istream&, const DeviceDescriptor&, const UDFDeserializeCallback&);
%rename(load_from_buffer) CNTK::Function::Load(const char*, size_t, const CNTK::DeviceDescriptor&, const CNTK::Internal::UDFDeserializeCallbackWrapper&);
%ignore CNTK::Function::Load(std::istream&, const DeviceDescriptor&);
%feature("director") CNTK::Learner;
%feature("nodirector") CNTK::Learner::Parameters;
@ -812,7 +828,7 @@ public:
}
%enddef
// Implementing typemapping for UDFDeserializeCallback (it has a dictionary
// Implementing typemapping for UDFDeserializeCallbackWrapper (it has a dictionary
// as one of its input parameters), which needs to be implemented in Python.
%typemap(directorin, fragment="DictionaryValueToPy") const CNTK::Dictionary&
@ -1403,6 +1419,7 @@ std::unordered_map<CNTK::StreamInformation, std::pair<CNTK::NDArrayViewPtr, CNTK
%shared_ptr(CNTK::DistributedLearner)
%shared_ptr(CNTK::Internal::TensorBoardFileWriter)
%shared_ptr(CNTK::ProgressWriter)
%shared_ptr(CNTK::Internal::UDFDeserializeCallbackWrapper)
%include "CNTKLibraryInternals.h"
%include "CNTKLibrary.h"
@ -1784,6 +1801,7 @@ namespace CNTK
%template(random_uniform_float) CNTK::NDArrayView::RandomUniform<float>;
%template(random_uniform_double) CNTK::NDArrayView::RandomUniform<double>;
%template(DictionaryValueFromDict) CNTK::DictionaryValue::DictionaryValue<CNTK::Dictionary>;
%template(DictionaryValueFromNDArrayView) CNTK::DictionaryValue::DictionaryValue<CNTK::NDArrayView>;
%template(training_parameter_per_sample_schedule) CNTK::TrainingParameterPerUnitSchedule<double, CNTK::TrainingParameterSchedule<double>::UnitType::Sample>;
%template(training_parameter_per_minibatch_schedule) CNTK::TrainingParameterPerUnitSchedule<double, CNTK::TrainingParameterSchedule<double>::UnitType::Minibatch>;

Просмотреть файл

@ -23,31 +23,38 @@ def _value_as_sequence_or_array(val, var):
map_if_possible(val)
return val.asarray()
_serialization_version = 1
def _serialize(udf):
dictionary = {}
dictionary['class'] = udf.__class__.__name__
dictionary['module'] = udf.__class__.__module__
dictionary['op_name'] = udf.op_name
dictionary['state'] = udf.serialize()
dictionary['version'] = _serialization_version
return dictionary
class _UDFDeserializeCallbackWrapper(cntk_py.UDFDeserializeCallbackWrapper):
'''
Provides an implementation of the UDFDeserializer interface used
to inflate user defined functions in a model dictionary.
'''
def __init__(self, factory_callback_map=None):
super(_UDFDeserializeCallbackWrapper, self).__init__()
self.factory_callback_map = factory_callback_map
self.__disown__()
def __call__(self, inputs, name, dictionary):
import pdb; pdb.set_trace()
cls = dictionary['class']
module = dictionary['module']
state = dictionary['state']
op_name = dictionary['op_name']
deserialize_method = 'deserialize'
if (self.factory_callback_map and op_name in self.factory_callback_map):
factory = self.factory_callback_map[op_name]
else:
exec("from {} import {}".format(module, cls))
eval_str = "{0}.deserialize if hasattr({0}, 'deserialize') else None"
factory = eval(eval_str.format(cls))
eval_str = "{0}.{1} if hasattr({0}, '{1}') else None"
factory = eval(eval_str.format(cls, deserialize_method))
if (factory):
if factory:
return factory(list(inputs), name, state)
raise ValueError("Cannot deserialize user function '{}.{}'. "

Просмотреть файл

@ -18,6 +18,15 @@ def is_string(s):
'''
return isinstance(s, ("".__class__, u"".__class__))
def is_byte_buffer(s):
'''
Tests whether ``s`` is a byte buffer (not a string) in a way that
works on Python 2 and 3.
'''
return (isinstance(s, bytearray) or
(isinstance(s, type(b'')) and not isinstance(b'', str)))
def _as_tuple(x):
'''
Convert an argument to a tuple.

Просмотреть файл

@ -6,6 +6,8 @@
from .. import cntk_py
import numpy as np
from cntk import NDArrayView
from ..cntk_py import DictionaryValueFromDict, DictionaryValue, Dictionary, DictionaryValueFromNDArrayView
_VARIABLE_OR_FUNCTION = (cntk_py.Variable, cntk_py.Function)
@ -206,7 +208,6 @@ def _py_dict_to_cntk_dict(py_dict):
cntk_py.Dictionary:
A :class:`~cntk.cntk_py.Dictionary` that has been converted from the input `dict`
'''
from ..cntk_py import DictionaryValueFromDict, DictionaryValue, Dictionary
def _to_cntk_dict_value(py_value):
if isinstance(py_value, dict):
return DictionaryValueFromDict(_py_dict_to_cntk_dict(py_value))
@ -214,7 +215,14 @@ def _py_dict_to_cntk_dict(py_dict):
if isinstance(py_value, list):
py_list = list(map(_to_cntk_dict_value, py_value))
return DictionaryValue(py_list)
if isinstance(py_value, np.ndarray):
py_value = NDArrayView.from_dense(py_value)
return DictionaryValueFromNDArrayView(py_value)
if py_value is None:
return DictionaryValue()
return DictionaryValue(py_value)
res = Dictionary()

Просмотреть файл

@ -13,7 +13,8 @@ from cntk.internal import map_if_possible, typemap, sanitize_var_map,\
_value_as_sequence_or_array
from cntk.internal.utils import get_python_function_arguments, \
map_function_arguments, _py_dict_to_cntk_dict
from cntk.internal import _UDFDeserializeCallbackWrapper
from cntk.internal import _UDFDeserializeCallbackWrapper, _serialize
from cntk.internal.sanitize import is_byte_buffer
from ..variables import Record, Variable
@ -90,6 +91,10 @@ class Function(cntk_py.Function):
in :func:`~cntk.ops.as_block()`, using the supplied ``op_name`` and ``name`` parameters, which are otherwise ignored.
'''
_udf_callback_map = {}
_deserializer = _UDFDeserializeCallbackWrapper(_udf_callback_map)
cntk_py._register_udf_deserialize_callback(_deserializer)
# We override the constructors to implement an overload that constructs
# a CNTK Functions from a Python function (@Function).
def __new__(cls, *args, **kwargs):
@ -1135,25 +1140,43 @@ class Function(cntk_py.Function):
'restore(...) instead', DeprecationWarning)
return self.restore(filename)
@staticmethod
def register_udf_deserialize_callback(op_name, callback):
'''
Register a callback function to be invoked when deserializing a user-
defined function with the corresponding op name.
When loading a model, CNTK will try to automatically reconstruct any
(non-native) user-defined functions by invoking a static
:func:`~cntk.ops.functions.UserFunction.deserialize` method of the
corresponding UserFunction sub-class. This method allows to override
default UDF deserialization behavior by specifying a user- defined
function op name and the corresponding callback that should be invoked
instead of the ``deserialize`` method.
Args:
op_name (str): unique op name of the user-defined function.
callback (function): a function taking three arguments (a list of
inputs to the UserFunction, a string name, and a state dictionary
generated by the corresponding :func:`~cntk.ops.functions.UserFunction.serialize`
method) and returns an instance of the user-defined function.
'''
if op_name in Function._udf_callback_map:
raise ValueError("A callback for the UserFunction with op name {}"
" has already been registered.".format(op_name));
Function._udf_callback_map[op_name] = callback
@staticmethod
@typemap
def load(model, device=None, udf_factory_callback_map=None):
def load(model, device=None):
'''
Load the ``model``, that has been saved using :func:`~cntk.ops.functions.Function.save`.
Args:
model (str or bytes): either a filepath of a model file or a byte buffer
model (str, bytes or bytearray): either a file path of a model file or a byte buffer
containing the binary representation of a model.
device (:class:`~cntk.device.DeviceDescriptor`, defaults to the current globally default device):
specifies the device to allocate the model on.
udf_factory_callback_map (dict, default is `None`): if the model contains any user-defined
functions, CNTK will try to automatically reconstruct them by invoking a static
``deserialize`` method of the corresponding Function sub-class. This method takes three
arguments (a list of inputs to the function, a string name, and a state dictionary
generated by the corresponding :func:`~cntk.ops.functions.UserFunction.serialize` method) and
returns an instance of the user-defined function. This optional argument allows to override
default UDF deserialization behavior by providing a map of user-function op names and
corresponding lambdas that should be invoked instead of the ``deserialize`` method.
Returns:
root node
@ -1161,10 +1184,7 @@ class Function(cntk_py.Function):
if not device:
device = DeviceDescriptor.use_default_device()
deserializer = _UDFDeserializeCallbackWrapper(udf_factory_callback_map)
is_buffer = isinstance(model, type(b'')) and not isinstance(b'', str)
is_buffer = is_buffer or isinstance(model, bytearray)
is_buffer = is_byte_buffer(model)
is_file = False
if not is_buffer:
@ -1174,10 +1194,10 @@ class Function(cntk_py.Function):
pass
if is_buffer:
return cntk_py.Function.load_from_buffer(model, device, deserializer)
return cntk_py.Function.load_from_buffer(model, device)
if is_file:
return cntk_py.Function.load(model, device, deserializer)
return cntk_py.Function.load(model, device)
raise ValueError('Cannot load a model that is neither a file nor a byte buffer.')
@ -1225,11 +1245,11 @@ def native_user_function(op_name, operands, attributes=None, user_function_insta
return cntk_py.Function_native_user_function(op_name, operands, attributes, user_function_instance_name)
@typemap
def load_model(model, device=None, udf_factory_callback_map=None):
def load_model(model, device=None):
'''
Alias for :func:`~cntk.ops.functions.Function.load`.
'''
return Function.load(model, device, udf_factory_callback_map)
return Function.load(model, device)
@typemap
def save_model(model, filename): # legacy name
@ -1251,8 +1271,10 @@ class UserFunction(Function):
`False` passes the data as CNTK Value objects.
name (str): name of this function
'''
def __init__(self, inputs, as_numpy=True, name=''):
super(UserFunction, self).__init__(inputs, name)
self.set_native(False)
self.as_numpy = as_numpy
# Since the state will frequently not be used, we cache the None-state
@ -1402,14 +1424,36 @@ class UserFunction(Function):
'''
raise NotImplementedError('clone has to be overwritten')
def _serialize(self):
dictionary = {}
dictionary['class'] = self.__class__.__name__
dictionary['module'] = self.__class__.__module__
dictionary['op_name'] = self.op_name
dictionary['state'] = self.serialize()
def _serialize_impl(self):
dictionary = _serialize(self)
return _py_dict_to_cntk_dict(dictionary)
@staticmethod
def deserialize(inputs, name, state):
'''
A stub deserialize method for illustration purposes. User-defined functions
need to provide their own implementation in order for CNTK to be able to
reconstruct them when loading a model.
Args:
inputs (list): a list of inputs to the function
name (str): name of this function
state (dict): a state dictionary generated by the corresponding
:func:`~cntk.ops.functions.UserFunction.serialize` method.
Returns:
An instance of the user-defined function.
'''
raise NotImplementedError('a stub method for illustration purposes.')
@property
def op_name(self):
'''
Unique operation name of this user-defined function.
This property defaults to '<module>.<class>', but can be overridden.
'''
return self.__class__._op_name()
def serialize(self):
'''
Generates a dictionary that captures the state of this user-defined function.
@ -1418,3 +1462,7 @@ class UserFunction(Function):
to be preserved in the model dictionary.
'''
return {}
@classmethod
def _op_name(cls):
return cls.__module__ + '.' + cls.__name__

Просмотреть файл

@ -49,10 +49,8 @@ class MyPlus(UserFunction):
input_gradients[input] = root_gradients
def serialize(self):
internal_state = {}
internal_state['forward_calls'] = self.forward_calls
internal_state['backward_calls'] = self.backward_calls
return internal_state
return {'forward_calls' : self.forward_calls,
'backward_calls' : self.backward_calls}
@staticmethod
def deserialize(inputs, name, state):
@ -61,6 +59,24 @@ class MyPlus(UserFunction):
f.backward_calls = state['backward_calls']
return f
class MyPlusPlus(MyPlus):
def __init__(self, inputs, name, state={}):
super(MyPlusPlus, self).__init__(inputs[0], inputs[1], name=name+name)
def forward(self, *args, **kwargs):
r1 = super(MyPlusPlus, self).forward(*args, **kwargs)
r2 = super(MyPlusPlus, self).forward(*args, **kwargs)
return None, r1[1] + r2[1]
def serialize(self):
return None
@staticmethod
def deserialize(*args):
return MyPlusPlus(*args)
def test_ext_eval_1():
dim = 4
p = parameter(shape=(dim,), init=10, name='p')
@ -315,12 +331,14 @@ def test_ext_lambdafunc(tmpdir):
filepath = str(tmpdir / 'test_ext_lambdafunc.dat')
z0.save(filepath)
z = Function.load(filepath, udf_factory_callback_map =
{
'conditional_exec_lambda' :
lambda x, *unused:
LambdaFunc(x, when=lambda arg: np.sum(arg)>1, execute=cb.inc)
})
Function.register_udf_deserialize_callback(
'conditional_exec_lambda',
lambda x, *unused:
LambdaFunc(x, when=lambda arg: np.sum(arg)>1, execute=cb.inc))
z = Function.load(filepath)
momentum_time_constant = momentum_as_time_constant_schedule(1100)
lr_per_sample = learning_rate_schedule(0.007, UnitType.sample)
trainer = Trainer(z, (z+0, z+0), \
@ -484,18 +502,136 @@ def test_udf_output_needs_no_gradient():
assert np.allclose(result, [[[6., 8.], [10., 12.]]])
def test_native_user_function():
ops.register_native_user_function('NativeUserTimesFunction', 'Cntk.ExtensibilityExamples-' + C.__version__.rstrip('+'), 'CreateUserTimesFunction')
def test_native_user_function(tmpdir):
if not cntk_py.is_native_user_function_registered('NativeUserTimesOp'):
ops.register_native_user_function('NativeUserTimesOp', 'Cntk.ExtensibilityExamples-' + C.__version__.rstrip('+'), 'CreateUserTimesFunction')
dev = cpu()
x = input((2))
w = parameter((2, 2), init=np.asarray([[0.5, 2], [-0.5, 1.5]], dtype=np.float32), device=dev)
attributes = {'param_rank' : 2, 'padding' : True}
op = ops.native_user_function('NativeUserTimesFunction', [w, x], attributes, 'native_user_times_function')
attributes = {
'param_rank' : 2,
'padding' : True,
'none': None,
'nested lists' : [[1,2,3], [4,5,6]],
'string' : 'string',
'some data' : np.arange(1, 10, dtype=np.float32).reshape((3,3))
}
def verify_attributes(udf):
for k,v in attributes.items():
if not isinstance(v, np.ndarray):
assert udf.attributes[k] == v
else:
assert (udf.attributes[k] == v).all()
op = ops.native_user_function('NativeUserTimesOp', [w, x], attributes, 'native_user_times_function')
verify_attributes(op.owner)
filepath = str(tmpdir / 'test_native_user_function.dat')
op.save(filepath)
op_reloaded = Function.load(filepath, device=dev)
x_data = NDArrayView.from_dense(np.asarray([[0.1, 0.2], [-0.1, 0.3]], dtype=np.float32), device=dev)
result = op.eval({x : x_data}, device=dev)
result = op_reloaded.eval({op_reloaded.arguments[0] : x_data}, device=dev)
assert np.allclose(result, [[-0.05, 0.5], [-0.2, 0.25]])
native_times_primitive = op.find_by_name('native_user_times_function')
assert native_times_primitive.attributes == attributes
native_times_primitive = op_reloaded.find_by_name('native_user_times_function')
verify_attributes(native_times_primitive)
def build_test_function():
dev = cpu()
w_value = np.asarray([[0.5, 2], [-0.5, 1.5]]).astype(np.float32)
c1_value = 2.718
c2_value = -3.141
if not cntk_py.is_native_user_function_registered('NativeUserTimesOp'):
ops.register_native_user_function('NativeUserTimesOp', 'Cntk.ExtensibilityExamples-' + C.__version__.rstrip('+'), 'CreateUserTimesFunction')
x = input((2))
w = parameter((2, 2), init=w_value, device=dev)
mp = MyPlus(x, constant(c1_value))
op = user_function(MyPlus(x, constant(c1_value)))
op = ops.native_user_function('NativeUserTimesOp', [w, op], user_function_instance_name='my_times')
return dev, w_value, c1_value, c2_value, user_function(MyPlus(op, constant(c2_value)))
def test_both_flavors_of_user_functions(tmpdir):
dev, w_value, c1_value, c2_value, op = build_test_function()
filepath = str(tmpdir / 'test_native_user_function.dat')
op.save(filepath)
op_reloaded = Function.load(filepath, device=dev)
np.random.seed(1)
for i in range(5):
x_value = np.random.random((2,2)).astype(np.float32)
x_data = NDArrayView.from_dense(x_value, device=dev)
result = op_reloaded.eval({op_reloaded.arguments[0] : x_data}, device=dev)
expected = np.matmul((x_value + c1_value), w_value) + c2_value
assert np.allclose(result, expected)
def test_udf_checkpointing(tmpdir):
dev, w_value, c1_value, c2_value, op = build_test_function()
label = constant(np.asarray([[1, 2], [3,4]]).astype(np.float32))
loss = cross_entropy_with_softmax(op, label)
eval_error = classification_error(op, label)
lr_schedule = learning_rate_schedule(0.5, UnitType.minibatch)
learner = sgd(op.parameters, lr_schedule)
trainer = Trainer(op, (loss, eval_error), [learner])
trainer.train_minibatch({op.arguments[0] : np.random.random((2,2)).astype(np.float32)}, device=dev)
filepath = str(tmpdir /'test_checkpointing.out')
trainer.save_checkpoint(filepath, external_state={'test':'test'})
d = cntk_py.Dictionary.load(filepath)
assert len(d.keys()) != 0
def test_override_deserialize(tmpdir):
dev, w_value, c1_value, c2_value, op = build_test_function()
filepath = str(tmpdir / 'test_override_deserialize.dat')
op.save(filepath)
Function.register_udf_deserialize_callback(
MyPlus._op_name(), lambda *x : MyPlusPlus(*x))
op_reloaded = Function.load(filepath, device=dev)
np.random.seed(1)
for i in range(5):
x_value = np.random.random((2,2)).astype(np.float32)
x_data = NDArrayView.from_dense(x_value, device=dev)
result = op_reloaded.eval({op_reloaded.arguments[0] : x_data}, device=dev)
expected = 2 *(np.matmul(2 * (x_value + c1_value), w_value) + c2_value)
assert np.allclose(result, expected)
def test_override_serialize(tmpdir):
dev=cpu()
a,b = 1.2322341, -0.29084;
op = MyPlusPlus([constant(a), constant(b)], '++');
op = MyPlusPlus([op, op], '+++')
op = MyPlusPlus([op, op], '++++')
op = user_function(op)
result1 = op.eval({}, device=dev)
filepath = str(tmpdir / 'test_udf_with_renamed_deserialize.dat')
op.save(filepath)
op_reloaded = Function.load(filepath, device=dev)
assert result1 == op_reloaded.eval({}, device=dev)

Просмотреть файл

@ -14,6 +14,7 @@ Implementing a custom operator in pure Python is simple matter of
- implementing ``forward()`` and ``backward()``, whose signatures dependent on the number of inputs and outputs
- specifying the outputs' shape, data type and dynamic axes in
``infer_outputs()``
- providing a static ``deserialize()`` method to inflate previously saved function
In the simplest case, just only one input and output, ``forward()`` takes an
argument and returns a tuple of a state and the result. The state can be used to
@ -46,6 +47,10 @@ tuple, strings, etc.)::
return [output_variable(self.inputs[0].shape, self.inputs[0].dtype,
self.inputs[0].dynamic_axes)]
@staticmethod
def deserialize(inputs, name, state):
return = MySigmoid(inputs[0], name)
This can now be used as a normal operator like::
from cntk import user_function
@ -61,13 +66,17 @@ In case, the operator is initialized with multiple inputs, ``forward()`` 's
class MyPlus(UserFunction):
def __init__(self, arg1, arg2, name='f1'):
super(MyPlus, self).__init__([arg1, arg2], name=name)
self.forward_calls = 0
self.backward_calls = 0
def forward(self, arguments, device=None, outputs_to_retain=None):
# No state needs to be passed to backward() so we just
# pass None
self.forward_calls += 1
return None, arguments[0] + arguments[1]
def backward(self, state, root_gradients):
self.backward_calls += 1
return root_gradients
def infer_outputs(self):
@ -76,6 +85,17 @@ In case, the operator is initialized with multiple inputs, ``forward()`` 's
# result would actually look like (considering broadcasting, etc.).
return [output_variable(self.inputs[0].shape, self.inputs[0].dtype, self.inputs[0].dynamic_axes)]
def serialize(self):
return {'forward_calls' : self.forward_calls,
'backward_calls' : self.backward_calls}
@staticmethod
def deserialize(inputs, name, state):
f = MyPlus(inputs[0], inputs[1], name)
f.forward_calls = state['forward_calls']
f.backward_calls = state['backward_calls']
return f
If the UserFunction has more than one input, ``backward()`` is invoked
with an additional ``variables`` argument, which contains a dictionary of
Variable to the gradient data, whose values have to be set with the proper
@ -101,6 +121,13 @@ have to be set with the proper data.
In addition, ``root_gradient`` in ``backward()`` is a dictionary of Variable to the
root_gradient.
``deserialize()`` is invoked by CNTK to reconstruct a previously saved function. It should
have the same signature as :func:`~cntk.ops.functions.UserFunction.deserialize` method.
In case of a stateless function, it simply needs to invoke the constructor and return an
instance of the user function. However, if the function is stateful and overrides
:func:`~cntk.ops.functions.UserFunction.serialize` method, ``deserialize()`` also needs to
properly restore the function state.
Using user functions for debugging
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~