CNTK/bindings/python/cntk/internal/swig_helper.py

73 lines
2.9 KiB
Python

# Copyright (c) Microsoft. All rights reserved.
# Licensed under the MIT license. See LICENSE.md file in the project root
# for full license information.
# ==============================================================================
from functools import wraps
from .. import cntk_py
# Lazily-built table {cntk_py SWIG class: cntk class}; filled on the first
# call to map_if_possible.
_typemap = None


def _build_typemap():
    '''
    Build the table mapping each SWIG proxy class exposed by ``cntk_py`` to
    the cntk Python class that inherits from it.

    The imports are performed here, at call time, rather than at module
    level: they can only succeed once ``cntk_py`` and the cntk classes are
    already known, which is the case by the time map_if_possible runs
    (presumably this avoids a circular import with this module -- TODO
    confirm).

    Returns:
        dict: ``cntk_py`` class -> cntk class.
    '''
    from cntk import Value, NDArrayView
    from cntk.axis import Axis
    from cntk.device import DeviceDescriptor
    from cntk.io import MinibatchSource, MinibatchData, StreamConfiguration
    from cntk.learners import Learner
    from cntk.ops.functions import Function
    from cntk.train.trainer import Trainer
    from cntk.train.training_session import TrainingSession
    from cntk.train.distributed import (WorkerDescriptor, Communicator,
                                        DistributedLearner)
    from cntk.variables import Variable, Parameter, Constant

    return {
        cntk_py.Axis: Axis,
        cntk_py.Constant: Constant,
        cntk_py.DeviceDescriptor: DeviceDescriptor,
        cntk_py.DistributedWorkerDescriptor: WorkerDescriptor,
        cntk_py.DistributedCommunicator: Communicator,
        cntk_py.DistributedLearner: DistributedLearner,
        cntk_py.Function: Function,
        cntk_py.Learner: Learner,
        cntk_py.MinibatchData: MinibatchData,
        cntk_py.MinibatchSource: MinibatchSource,
        cntk_py.NDArrayView: NDArrayView,
        cntk_py.Parameter: Parameter,
        cntk_py.StreamConfiguration: StreamConfiguration,
        cntk_py.Trainer: Trainer,
        cntk_py.TrainingSession: TrainingSession,
        cntk_py.Value: Value,
        cntk_py.Variable: Variable,
    }


def map_if_possible(obj):
    '''
    Upcast ``obj`` in place from its SWIG proxy class to the matching cntk
    class, descending into containers.

    If the exact class of ``obj`` is registered in the type map, its
    ``__class__`` is reassigned and processing stops there. Otherwise:

    - for a tuple, list or set, every element is processed;
    - for a dict, every key and every value is processed.

    Processing is recursive. Any other object is left untouched.

    Args:
        obj: any value, typically something returned by a ``cntk_py`` call.

    Returns:
        None. The conversion works by mutating ``__class__`` on the objects
        reached.
    '''
    global _typemap
    if _typemap is None:
        # Built on first use; see _build_typemap for why it is deferred.
        _typemap = _build_typemap()

    # Only registered classes get their __class__ reassigned: some types
    # (NumPy arrays, for example) do not allow __class__ to be changed.
    target_class = _typemap.get(obj.__class__)
    if target_class is not None:
        obj.__class__ = target_class
        return

    if isinstance(obj, (tuple, list, set)):
        for item in obj:
            map_if_possible(item)
    elif isinstance(obj, dict):
        for pair in obj.items():
            # Each pair is (key, value): the key is processed first, then the value.
            for part in pair:
                map_if_possible(part)
def typemap(f):
    '''
    Decorator that upcasts the return value of ``f`` from SWIG types to the
    cntk types that inherit from them.

    The upcast is recursive, with the same rules as map_if_possible. For
    example, if ``f`` returns a tuple that contains a dictionary, every
    element of the tuple is upcast where possible, as are every key and
    every value of the dictionary. The value is modified in place and then
    returned unchanged in identity.

    Args:
        f: the callable whose results should be upcast.

    Returns:
        A wrapper with the same name and docstring as ``f``, copied via
        functools.wraps.
    '''
    @wraps(f)
    def _upcasting_wrapper(*args, **kwargs):
        value = f(*args, **kwargs)
        map_if_possible(value)
        return value

    return _upcasting_wrapper