151 строка
5.3 KiB
Python
151 строка
5.3 KiB
Python
# -*- coding: utf-8 -*-
|
|
# File: registry.py
|
|
|
|
|
|
import copy
|
|
import re
|
|
from functools import wraps
|
|
import six
|
|
import tensorflow as tf
|
|
|
|
from ..compat import tfv1
|
|
from ..tfutils.argscope import get_arg_scope
|
|
from ..tfutils.model_utils import get_shape_str
|
|
from ..utils import logger
|
|
|
|
# make sure each layer is only logged once
|
|
_LAYER_LOGGED = set()
|
|
_LAYER_REGISTRY = {}
|
|
|
|
__all__ = ['layer_register']
|
|
|
|
|
|
def _register(name, func):
|
|
if name in _LAYER_REGISTRY:
|
|
raise ValueError("Layer named {} is already registered!".format(name))
|
|
if name in ['tf']:
|
|
raise ValueError(logger.error("A layer cannot be named {}".format(name)))
|
|
_LAYER_REGISTRY[name] = func
|
|
|
|
# handle alias
|
|
if name == 'Conv2DTranspose':
|
|
_register('Deconv2D', func)
|
|
|
|
|
|
def get_registered_layer(name):
|
|
"""
|
|
Args:
|
|
name (str): the name of the layer, e.g. 'Conv2D'
|
|
Returns:
|
|
the wrapped layer function, or None if not registered.
|
|
"""
|
|
return _LAYER_REGISTRY.get(name, None)
|
|
|
|
|
|
def disable_layer_logging():
|
|
"""
|
|
Disable the shape logging for all layers from this moment on. Can be
|
|
useful when creating multiple towers.
|
|
"""
|
|
class ContainEverything:
|
|
def __contains__(self, x):
|
|
return True
|
|
# can use nonlocal in python3, but how
|
|
globals()['_LAYER_LOGGED'] = ContainEverything()
|
|
|
|
|
|
def layer_register(
|
|
log_shape=False,
|
|
use_scope=True):
|
|
"""
|
|
Args:
|
|
log_shape (bool): log input/output shape of this layer
|
|
use_scope (bool or None):
|
|
Whether to call this layer with an extra first argument as scope.
|
|
When set to None, it can be called either with or without
|
|
the scope name argument.
|
|
It will try to figure out by checking if the first argument
|
|
is string or not.
|
|
|
|
Returns:
|
|
A decorator used to register a layer.
|
|
|
|
Example:
|
|
|
|
.. code-block:: python
|
|
|
|
@layer_register(use_scope=True)
|
|
def add10(x):
|
|
return x + tf.get_variable('W', shape=[10])
|
|
"""
|
|
|
|
def wrapper(func):
|
|
@wraps(func)
|
|
def wrapped_func(*args, **kwargs):
|
|
assert args[0] is not None, args
|
|
if use_scope:
|
|
name, inputs = args[0], args[1]
|
|
args = args[1:] # actual positional args used to call func
|
|
assert isinstance(name, six.string_types), "First argument for \"{}\" should be a string. ".format(
|
|
func.__name__) + "Did you forget to specify the name of the layer?"
|
|
else:
|
|
assert not log_shape
|
|
if isinstance(args[0], six.string_types):
|
|
if use_scope is False:
|
|
logger.warn(
|
|
"Please call layer {} without the first scope name argument, "
|
|
"or register the layer with use_scope=None to allow "
|
|
"two calling methods.".format(func.__name__))
|
|
name, inputs = args[0], args[1]
|
|
args = args[1:] # actual positional args used to call func
|
|
else:
|
|
inputs = args[0]
|
|
name = None
|
|
if not (isinstance(inputs, (tf.Tensor, tf.Variable)) or
|
|
(isinstance(inputs, (list, tuple)) and
|
|
isinstance(inputs[0], (tf.Tensor, tf.Variable)))):
|
|
raise ValueError("Invalid inputs to layer: " + str(inputs))
|
|
|
|
# use kwargs from current argument scope
|
|
actual_args = copy.copy(get_arg_scope()[func.__name__])
|
|
# explicit kwargs overwrite argscope
|
|
actual_args.update(kwargs)
|
|
# if six.PY3:
|
|
# # explicit positional args also override argscope. only work in PY3
|
|
# posargmap = inspect.signature(func).bind_partial(*args).arguments
|
|
# for k in six.iterkeys(posargmap):
|
|
# if k in actual_args:
|
|
# del actual_args[k]
|
|
|
|
if name is not None: # use scope
|
|
with tfv1.variable_scope(name) as scope:
|
|
# this name is only used to surpress logging, doesn't hurt to do some heuristics
|
|
scope_name = re.sub('tower[0-9]+/', '', scope.name)
|
|
do_log_shape = log_shape and scope_name not in _LAYER_LOGGED
|
|
if do_log_shape:
|
|
logger.info("{} input: {}".format(scope.name, get_shape_str(inputs)))
|
|
|
|
# run the actual function
|
|
outputs = func(*args, **actual_args)
|
|
|
|
if do_log_shape:
|
|
# log shape info and add activation
|
|
logger.info("{} output: {}".format(
|
|
scope.name, get_shape_str(outputs)))
|
|
_LAYER_LOGGED.add(scope_name)
|
|
|
|
if hasattr(outputs, 'info'):
|
|
logger.info("{} flops (multi-add): {}".format(
|
|
scope.name, outputs.info.flops))
|
|
else:
|
|
# run the actual function
|
|
outputs = func(*args, **actual_args)
|
|
return outputs
|
|
|
|
wrapped_func.symbolic_function = func # attribute to access the underlying function object
|
|
wrapped_func.use_scope = use_scope
|
|
_register(func.__name__, wrapped_func)
|
|
return wrapped_func
|
|
|
|
return wrapper
|