Mirror of https://github.com/microsoft/LQ-Nets.git
381 lines
15 KiB
Python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: learned_quantization.py

import tensorflow as tf
from tensorflow.contrib.framework import add_model_variable
from tensorflow.python.training import moving_averages
from tensorpack.models import *
from tensorpack.tfutils.tower import get_current_tower_context

MOVING_AVERAGES_FACTOR = 0.9
EPS = 0.0001
NORM_PPF_0_75 = 0.6745
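# Editor's note (not in the upstream file): NORM_PPF_0_75 is the 75th-percentile
# point of the standard normal distribution, i.e. scipy.stats.norm.ppf(0.75) ~= 0.6745;
# it is used below to set the scale of the initial quantization basis.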


@layer_register()
def QuantizedActiv(x, nbit=2):
    """
    Quantize activation.

    Args:
        x (tf.Tensor): a 4D tensor.
        nbit (int): number of bits of quantized activation. Defaults to 2.

    Returns:
        tf.Tensor with attribute `variables`.

    Variable Names:

    * ``basis``: basis of quantized activation.

    Note:
        About multi-GPU training: moving averages across GPUs are not aggregated.
        Batch statistics are computed by the main training tower. This is consistent with most frameworks.
    """
    init_basis = [(NORM_PPF_0_75 * 2 / (2 ** nbit - 1)) * (2. ** i) for i in range(nbit)]
    init_basis = tf.constant_initializer(init_basis)
    bit_dims = [nbit, 1]
    num_levels = 2 ** nbit
    # initialize level multiplier
    init_level_multiplier = []
    for i in range(0, num_levels):
        level_multiplier_i = [0. for j in range(nbit)]
        level_number = i
        for j in range(nbit):
            level_multiplier_i[j] = float(level_number % 2)
            level_number = level_number // 2
        init_level_multiplier.append(level_multiplier_i)
    # initialize threshold multiplier
    init_thrs_multiplier = []
    for i in range(1, num_levels):
        thrs_multiplier_i = [0. for j in range(num_levels)]
        thrs_multiplier_i[i - 1] = 0.5
        thrs_multiplier_i[i] = 0.5
        init_thrs_multiplier.append(thrs_multiplier_i)
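    # Editor's illustration (not in the upstream file): for nbit=2 the loops above
    # produce the binary level codes (least-significant bit first) and the
    # threshold-averaging matrix
    #   init_level_multiplier = [[0., 0.], [1., 0.], [0., 1.], [1., 1.]]
    #   init_thrs_multiplier  = [[0.5, 0.5, 0., 0.], [0., 0.5, 0.5, 0.], [0., 0., 0.5, 0.5]]
    # so each threshold is the midpoint of two adjacent quantization levels.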

    with tf.variable_scope('ActivationQuantization'):
        basis = tf.get_variable(
            'basis', bit_dims, tf.float32,
            initializer=init_basis,
            trainable=False)

        ctx = get_current_tower_context()  # current tower context
        # calculate levels and sort
        level_codes = tf.constant(init_level_multiplier)
        levels = tf.matmul(level_codes, basis)
        levels, sort_id = tf.nn.top_k(tf.transpose(levels, [1, 0]), num_levels)
        levels = tf.reverse(levels, [-1])
        sort_id = tf.reverse(sort_id, [-1])
        levels = tf.transpose(levels, [1, 0])
        sort_id = tf.transpose(sort_id, [1, 0])
        # calculate threshold
        thrs_multiplier = tf.constant(init_thrs_multiplier)
        thrs = tf.matmul(thrs_multiplier, levels)
        # calculate output y and its binary code
        y = tf.zeros_like(x)  # output
        reshape_x = tf.reshape(x, [-1])
        zero_dims = tf.stack([tf.shape(reshape_x)[0], nbit])
        bits_y = tf.fill(zero_dims, 0.)
        zero_y = tf.zeros_like(x)
        zero_bits_y = tf.fill(zero_dims, 0.)
        for i in range(num_levels - 1):
            g = tf.greater(x, thrs[i])
            y = tf.where(g, zero_y + levels[i + 1], y)
            bits_y = tf.where(tf.reshape(g, [-1]), zero_bits_y + level_codes[sort_id[i + 1][0]], bits_y)
        # training
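        # Editor's note (not in the upstream file): the block below refits the basis by
        # least squares. With B the {0, 1} bit matrix selected above and x the flattened
        # activations, it forms the normal equations and solves
        #     new_basis = (B^T B)^{-1} B^T x
        # then tracks the solution with a moving average instead of backpropagating
        # through the (non-trainable) basis variable.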
        if ctx.is_main_training_tower:
            BT = tf.matrix_transpose(bits_y)
            # calculate BTxB
            BTxB = []
            for i in range(nbit):
                for j in range(nbit):
                    BTxBij = tf.multiply(BT[i], BT[j])
                    BTxBij = tf.reduce_sum(BTxBij)
                    BTxB.append(BTxBij)
            BTxB = tf.reshape(tf.stack(values=BTxB), [nbit, nbit])
            BTxB_inv = tf.matrix_inverse(BTxB)
            # calculate BTxX
            BTxX = []
            for i in range(nbit):
                BTxXi0 = tf.multiply(BT[i], reshape_x)
                BTxXi0 = tf.reduce_sum(BTxXi0)
                BTxX.append(BTxXi0)
            BTxX = tf.reshape(tf.stack(values=BTxX), [nbit, 1])

            new_basis = tf.matmul(BTxB_inv, BTxX)  # calculate new basis
            # create moving averages op
            update_moving_basis = moving_averages.assign_moving_average(
                basis, new_basis, MOVING_AVERAGES_FACTOR)
            add_model_variable(basis)
            tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_moving_basis)

            for i in range(nbit):
                tf.summary.scalar('basis%d' % i, new_basis[i][0])

        x_clip = tf.minimum(x, levels[num_levels - 1])  # gradient clip
        y = x_clip + tf.stop_gradient(-x_clip) + tf.stop_gradient(y)  # gradient: y = clip(x) (straight-through estimator)
        y.variables = VariableHolder(basis=basis)
        return y
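
# Usage sketch (editorial, not part of the upstream file): QuantizedActiv is a
# registered tensorpack layer, so it is called with a scope name as the first
# argument, as BNReLUQuant below does:
#
#     act = QuantizedActiv('quant', feature_map, nbit=2)
#
# `feature_map` is a placeholder name for any 4D activation tensor.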


def QuantizedWeight(name, x, n, nbit=2):
    """
    Quantize weight.

    Args:
        x (tf.Tensor): a 4D tensor.
            Must have known number of channels, but can have other unknown dimensions.
        name (str): operator's name.
        n (int or float): fan-in of the layer, used to scale the initial basis
            (the weights are assumed to be initialized with variance 2 / n).
        nbit (int): number of bits of quantized weight. Defaults to 2.

    Returns:
        tf.Tensor with attribute `variables`.

    Variable Names:

    * ``basis``: basis of quantized weight.

    Note:
        About multi-GPU training: moving averages across GPUs are not aggregated.
        Batch statistics are computed by the main training tower. This is consistent with most frameworks.
    """
    num_filters = x.get_shape().as_list()[-1]
    init_basis = []
    base = NORM_PPF_0_75 * ((2. / n) ** 0.5) / (2 ** (nbit - 1))
    for j in range(nbit):
        init_basis.append([(2 ** j) * base for i in range(num_filters)])
    init_basis = tf.constant_initializer(init_basis)
    bit_dims = [nbit, num_filters]
    num_levels = 2 ** nbit
    delta = EPS
    # initialize level multiplier
    init_level_multiplier = []
    for i in range(num_levels):
        level_multiplier_i = [0. for j in range(nbit)]
        level_number = i
        for j in range(nbit):
            binary_code = level_number % 2
            if binary_code == 0:
                binary_code = -1
            level_multiplier_i[j] = float(binary_code)
            level_number = level_number // 2
        init_level_multiplier.append(level_multiplier_i)
    # initialize threshold multiplier
    init_thrs_multiplier = []
    for i in range(1, num_levels):
        thrs_multiplier_i = [0. for j in range(num_levels)]
        thrs_multiplier_i[i - 1] = 0.5
        thrs_multiplier_i[i] = 0.5
        init_thrs_multiplier.append(thrs_multiplier_i)
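    # Editor's illustration (not in the upstream file): unlike QuantizedActiv, the
    # weight level codes are signed; for nbit=2 the loop above yields (LSB first)
    #   init_level_multiplier = [[-1., -1.], [1., -1.], [-1., 1.], [1., 1.]]
    # so the 2**nbit levels are the symmetric combinations +/- basis_0 +/- basis_1.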

    with tf.variable_scope(name):
        basis = tf.get_variable(
            'basis', bit_dims, tf.float32,
            initializer=init_basis,
            trainable=False)
        level_codes = tf.constant(init_level_multiplier)
        thrs_multiplier = tf.constant(init_thrs_multiplier)
        sum_multiplier = tf.constant(1., shape=[1, tf.reshape(x, [-1, num_filters]).get_shape()[0]])
        sum_multiplier_basis = tf.constant(1., shape=[1, nbit])

        ctx = get_current_tower_context()  # current tower context
        # calculate levels and sort
        levels = tf.matmul(level_codes, basis)
        levels, sort_id = tf.nn.top_k(tf.transpose(levels, [1, 0]), num_levels)
        levels = tf.reverse(levels, [-1])
        sort_id = tf.reverse(sort_id, [-1])
        levels = tf.transpose(levels, [1, 0])
        sort_id = tf.transpose(sort_id, [1, 0])
        # calculate threshold
        thrs = tf.matmul(thrs_multiplier, levels)
        # calculate level codes per channel
        reshape_x = tf.reshape(x, [-1, num_filters])
        level_codes_channelwise_dims = tf.stack([num_levels * num_filters, nbit])
        level_codes_channelwise = tf.fill(level_codes_channelwise_dims, 0.)
        for i in range(num_levels):
            eq = tf.equal(sort_id, i)
            level_codes_channelwise = tf.where(tf.reshape(eq, [-1]), level_codes_channelwise + level_codes[i], level_codes_channelwise)
        level_codes_channelwise = tf.reshape(level_codes_channelwise, [num_levels, num_filters, nbit])
        # calculate output y and its binary code
        y = tf.zeros_like(x) + levels[0]  # output
        zero_dims = tf.stack([tf.shape(reshape_x)[0] * num_filters, nbit])
        bits_y = tf.fill(zero_dims, -1.)
        zero_y = tf.zeros_like(x)
        zero_bits_y = tf.fill(zero_dims, 0.)
        zero_bits_y = tf.reshape(zero_bits_y, [-1, num_filters, nbit])
        for i in range(num_levels - 1):
            g = tf.greater(x, thrs[i])
            y = tf.where(g, zero_y + levels[i + 1], y)
            bits_y = tf.where(tf.reshape(g, [-1]), tf.reshape(zero_bits_y + level_codes_channelwise[i + 1], [-1, nbit]), bits_y)
        bits_y = tf.reshape(bits_y, [-1, num_filters, nbit])
        # training
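        # Editor's note (not in the upstream file): the block below refits the basis
        # per output channel. For each channel it forms B^T B and B^T x with the
        # {-1, +1} bit tensor B, adds EPS on the diagonal so the system stays
        # invertible, solves for the new basis (with a closed-form 2x2 inverse when
        # nbit == 2), and folds the result into the moving average of `basis`.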
        if ctx.is_main_training_tower:
            BT = tf.transpose(bits_y, [2, 0, 1])
            # calculate BTxB
            BTxB = []
            for i in range(nbit):
                for j in range(nbit):
                    BTxBij = tf.multiply(BT[i], BT[j])
                    BTxBij = tf.matmul(sum_multiplier, BTxBij)
                    if i == j:
                        mat_one = tf.ones([1, num_filters])
                        BTxBij = BTxBij + (delta * mat_one)  # + E
                    BTxB.append(BTxBij)
            BTxB = tf.reshape(tf.stack(values=BTxB), [nbit, nbit, num_filters])
            # calculate inverse of BTxB
            if nbit > 2:
                BTxB_transpose = tf.transpose(BTxB, [2, 0, 1])
                BTxB_inv = tf.matrix_inverse(BTxB_transpose)
                BTxB_inv = tf.transpose(BTxB_inv, [1, 2, 0])
            elif nbit == 2:
                det = tf.multiply(BTxB[0][0], BTxB[1][1]) - tf.multiply(BTxB[0][1], BTxB[1][0])
                inv = []
                inv.append(BTxB[1][1] / det)
                inv.append(-BTxB[0][1] / det)
                inv.append(-BTxB[1][0] / det)
                inv.append(BTxB[0][0] / det)
                BTxB_inv = tf.reshape(tf.stack(values=inv), [nbit, nbit, num_filters])
            elif nbit == 1:
                BTxB_inv = tf.reciprocal(BTxB)
            # calculate BTxX
            BTxX = []
            for i in range(nbit):
                BTxXi0 = tf.multiply(BT[i], reshape_x)
                BTxXi0 = tf.matmul(sum_multiplier, BTxXi0)
                BTxX.append(BTxXi0)
            BTxX = tf.reshape(tf.stack(values=BTxX), [nbit, num_filters])
            BTxX = BTxX + (delta * basis)  # + basis
            # calculate new basis
            new_basis = []
            for i in range(nbit):
                new_basis_i = tf.multiply(BTxB_inv[i], BTxX)
                new_basis_i = tf.matmul(sum_multiplier_basis, new_basis_i)
                new_basis.append(new_basis_i)
            new_basis = tf.reshape(tf.stack(values=new_basis), [nbit, num_filters])
            # create moving averages op
            update_moving_basis = moving_averages.assign_moving_average(
                basis, new_basis, MOVING_AVERAGES_FACTOR)
            add_model_variable(basis)
            tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_moving_basis)

        y = x + tf.stop_gradient(-x) + tf.stop_gradient(y)  # gradient: y = x (straight-through estimator)
        y.variables = VariableHolder(basis=basis)
        return y
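
# Usage sketch (editorial, not part of the upstream file): QuantizedWeight takes the
# fan-in `n` of the layer whose kernel is being quantized; Conv2DQuant below does
# exactly this, e.g.
#
#     w_q = QuantizedWeight('weight_quant', W, n=k * k * out_channel, nbit=2)
#
# where W is the full-precision kernel and k the (square) kernel size.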


@layer_register()
def Conv2DQuant(x, out_channel, kernel_shape,
                padding='SAME', stride=1,
                W_init=None, b_init=None,
                nl=tf.identity, split=1, use_bias=True,
                data_format='NHWC', is_quant=True, nbit=1, fc=False):
    """
    2D convolution on 4D inputs, with optionally quantized weights.

    Args:
        x (tf.Tensor): a 4D tensor.
            Must have known number of channels, but can have other unknown dimensions.
        out_channel (int): number of output channels.
        kernel_shape (int): size of the square convolution kernel.
        stride (int): stride of the convolution.
        padding (str): 'valid' or 'same'. Case insensitive.
        split (int): split channels as used in AlexNet. Defaults to 1 (no split).
        W_init: initializer for W. Defaults to `variance_scaling_initializer`.
        b_init: initializer for b. Defaults to zero.
        nl: a nonlinearity function.
        use_bias (bool): whether to use bias.
        data_format (str): 'NHWC' or 'NCHW'. Defaults to 'NHWC'.
        is_quant (bool): whether to quantize the weight. Defaults to True.
        nbit (int): number of bits of quantized weight. Defaults to 1.
        fc (bool): whether to treat the layer as fully connected, i.e. reshape the
            input to 1x1 spatial size before the convolution. Defaults to False.

    Returns:
        tf.Tensor named ``output`` with attribute `variables`.

    Variable Names:

    * ``W``: weights
    * ``b``: bias
    """
    n = kernel_shape * kernel_shape * out_channel
    in_shape = x.get_shape().as_list()
    channel_axis = 3 if data_format == 'NHWC' else 1
    in_channel = in_shape[channel_axis]
    assert in_channel is not None, "[Conv2DQuant] Input cannot have unknown channel!"
    assert in_channel % split == 0
    assert out_channel % split == 0

    if fc:
        x = tf.reshape(x, [-1, in_channel, 1, 1])

    kernel_shape = [kernel_shape, kernel_shape]
    padding = padding.upper()
    filter_shape = kernel_shape + [in_channel // split, out_channel]

    if data_format == 'NCHW':
        stride = [1, 1, stride, stride]
    else:
        stride = [1, stride, stride, 1]

    if W_init is None:
        W_init = tf.contrib.layers.variance_scaling_initializer()
    if b_init is None:
        b_init = tf.constant_initializer()

    W = tf.get_variable('W', filter_shape, initializer=W_init)

    kernel_in = W * 1  # copy of the kernel that feeds the quantizer
    tf.summary.scalar('weight', tf.reduce_mean(tf.abs(W)))
    if is_quant:
        quantized_weight = QuantizedWeight('weight_quant', kernel_in, n, nbit=nbit)
    else:
        quantized_weight = kernel_in

    if use_bias:
        b = tf.get_variable('b', [out_channel], initializer=b_init)

    if split == 1:
        conv = tf.nn.conv2d(x, quantized_weight, stride, padding, data_format=data_format)
    else:
        inputs = tf.split(x, split, channel_axis)
        kernels = tf.split(quantized_weight, split, 3)
        outputs = [tf.nn.conv2d(i, k, stride, padding, data_format=data_format)
                   for i, k in zip(inputs, kernels)]
        conv = tf.concat(outputs, channel_axis)

    ret = nl(tf.nn.bias_add(conv, b, data_format=data_format) if use_bias else conv, name='output')
    ret.variables = VariableHolder(W=W)
    if use_bias:
        ret.variables.b = b
    if fc:
        ret = tf.reshape(ret, [-1, out_channel])
    return ret
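
# Usage sketch (editorial, not part of the upstream file): inside a tensorpack
# model, a 2-bit quantized 3x3 convolution with 64 output channels could be built as
#
#     x = Conv2DQuant('conv1', x, 64, 3, nbit=2, is_quant=True)
#
# argument order follows the signature above; the scope name 'conv1' is arbitrary.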


@layer_register(log_shape=False, use_scope=None)
def BNReLUQuant(x):
    """
    A shorthand of BatchNormalization + ReLU + QuantizedActiv.
    """
    x = BatchNorm('bn', x)
    x = tf.nn.relu(x)
    x = QuantizedActiv('quant', x)
    return x


def getBNReLUQuant(x, name=None):
    """
    A shorthand of BatchNormalization + ReLU + QuantizedActiv.
    """
    x = BatchNorm('bn', x)
    x = tf.nn.relu(x, name=name)
    x = QuantizedActiv('quant', x)
    return x


def getfcBNReLUQuant(x, name=None):
    """
    A shorthand of BatchNormalization + ReLU + QuantizedActiv after a fully-connected layer.
    """
    x = BatchNorm('bn', x, data_format='NHWC', use_scale=False, use_bias=False)
    x = tf.nn.relu(x, name=name)
    x = QuantizedActiv('quant', x)
    return x


def getfcBNReLU(x, name=None):
    """
    A shorthand of BatchNormalization + ReLU after a fully-connected layer.
    """
    x = BatchNorm('bn', x, data_format='NHWC', use_scale=False, use_bias=False)
    x = tf.nn.relu(x, name=name)
    return x
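
# Usage sketch (editorial, not part of the upstream file): the get* helpers above
# take the (x, name) signature expected by the `nl` argument of Conv2DQuant, so
# BatchNorm, ReLU and activation quantization can follow a quantized convolution:
#
#     x = Conv2DQuant('conv2', x, 128, 3, nl=getBNReLUQuant, nbit=2)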