зеркало из https://github.com/microsoft/LQ-Nets.git
75 строк
2.5 KiB
Python
75 строк
2.5 KiB
Python
|
#!/usr/bin/env python
|
||
|
# -*- coding: utf-8 -*-
|
||
|
# File: densenet_model.py
|
||
|
|
||
|
import math
|
||
|
|
||
|
import tensorflow as tf
|
||
|
from tensorflow.contrib.layers import variance_scaling_initializer
|
||
|
from tensorpack.models import *
|
||
|
from tensorpack.tfutils.argscope import argscope, get_arg_scope
|
||
|
|
||
|
from learned_quantization import Conv2DQuant, QuantizedActiv
|
||
|
|
||
|
# Channel count produced by the 3x3 convolution of every dense layer; that
# output is concatenated onto the layer input, so each add_layer call widens
# the feature map by this many channels (4x this is used for the 1x1 conv).
GROWTH_RATE = 32

# Fraction of input channels kept by the 1x1 convolution in each transition
# layer (the result is floored to an integer channel count).
REDUCTION = 0.5
|
||
|
|
||
|
|
||
|
def add_layer(name, l):
    """Build one quantized dense layer and concatenate its output onto `l`.

    Inside variable scope `name` the layer applies:
    1x1 Conv2DQuant (4 * GROWTH_RATE) -> BNReLU -> QuantizedActiv ->
    3x3 Conv2DQuant (GROWTH_RATE) -> BNReLU -> QuantizedActiv.
    The numbers in parentheses are the third argument of Conv2DQuant. This
    presumably means output channels, following tensorpack's
    Conv2D(name, x, out_channel, kernel_shape) convention. TODO confirm
    against learned_quantization.

    Args:
        name: variable-scope name for this layer.
        l: 4-D input feature map.

    Returns:
        `tf.concat([new_features, l])` along the channel axis, with the new
        features placed first.
    """
    # Follow the data_format that densenet_backbone puts on Conv2DQuant via
    # argscope instead of assuming NCHW. Fall back to NCHW (axis 1), the
    # layout this function previously hard-coded, when none is in scope.
    data_format = get_arg_scope()['Conv2DQuant'].get('data_format', 'NCHW')
    channel_axis = 3 if data_format in ('NHWC', 'channels_last') else 1
    with tf.variable_scope(name):
        c = Conv2DQuant('conv1x1', l, 4 * GROWTH_RATE, 1)
        c = BNReLU('bnrelu_2', c)
        c = QuantizedActiv('quant2', c)
        c = Conv2DQuant('conv3x3', c, GROWTH_RATE, 3)
        c = BNReLU('bnrelu_3', c)
        c = QuantizedActiv('quant3', c)
        l = tf.concat([c, l], channel_axis)
    return l
|
||
|
|
||
|
|
||
|
def add_transition(name, l):
    """Build a transition layer: 1x1 Conv2DQuant compression, then AvgPooling.

    The number of channels kept is floor(C * REDUCTION), where C is the
    static channel count of `l`. C is read from the channel axis implied by
    the Conv2DQuant argscope's data_format, defaulting to NCHW.

    Args:
        name: variable-scope name for this transition.
        l: 4-D input feature map with a statically known channel count.

    Returns:
        The result of AvgPooling('pool', ..., 2) on the compressed features.
    """
    # Use the same data_format that densenet_backbone sets on Conv2DQuant
    # instead of hard-coding shape[1]. Default to NCHW (the prior behaviour).
    data_format = get_arg_scope()['Conv2DQuant'].get('data_format', 'NCHW')
    channel_axis = 3 if data_format in ('NHWC', 'channels_last') else 1
    in_channel = l.get_shape().as_list()[channel_axis]
    # int(): math.floor returns a float on Python 2 but the conv needs an
    # integer channel count.
    out_channel = int(math.floor(in_channel * REDUCTION))
    with tf.variable_scope(name):
        l = Conv2DQuant('conv1', l, out_channel, 1, stride=1, use_bias=False)
        l = AvgPooling('pool', l, 2)
    return l
|
||
|
|
||
|
|
||
|
def add_dense_block(l, name, N, last=False, first=False):
    """Build one dense block, with an optional trailing transition layer.

    The tensor comes first in the signature so the function can be used with
    LinearWrap(...).apply(add_dense_block, name, N, ...).

    Args:
        l: input feature map.
        name: variable-scope name for the whole block. The transition layer
            is created inside this scope too, which keeps each block's
            'transition' variables distinct.
        N: number of dense layers (add_layer calls) in the block.
        last: when True, skip the transition layer at the end.
        first: when True, run BNReLU ('first') and then QuantizedActiv
            ('quant_first') on the input before the dense layers.

    Returns:
        The block output tensor.
    """
    with tf.variable_scope(name):
        if first:
            l = QuantizedActiv('quant_first', BNReLU('first', l))
        for layer_idx in range(N):
            l = add_layer('dense_layer.{}'.format(layer_idx), l)
        if not last:
            l = add_transition('transition', l)
    return l
|
||
|
|
||
|
|
||
|
def densenet_backbone(image, qw=1, num_classes=1000):
    """Build the quantized DenseNet and return the classifier logits.

    Layout: 7x7 stride-2 Conv2DQuant stem (2 * GROWTH_RATE channels,
    BNReLU, never weight-quantized) -> 3x3 stride-2 max pool -> dense blocks
    of 6, 12, 24 and 16 layers (the DenseNet-121 configuration), each
    followed by a transition except the last -> BNReLU -> global average
    pool -> linear classifier.

    Args:
        image: input tensor. It is expected in the layout named by the
            enclosing Conv2D argscope's data_format, which is forwarded to
            every Conv2DQuant.
        qw: forwarded to every Conv2DQuant as ``nbit`` (presumably the weight
            bit-width; TODO confirm in learned_quantization). Quantization is
            enabled (``is_quant``) only when ``qw > 0``. The stem conv
            overrides this with ``is_quant=False``.
        num_classes: output dimension of the final FullyConnected layer.
            Defaults to 1000, the value previously hard-coded.

    Returns:
        The logits tensor produced by the 'linear' FullyConnected layer.

    Raises:
        KeyError: if no ``data_format`` is set for Conv2D in the enclosing
            argscope.
    """
    with argscope(Conv2DQuant, nl=tf.identity, use_bias=False,
                  W_init=variance_scaling_initializer(mode='FAN_IN'),
                  data_format=get_arg_scope()['Conv2D']['data_format'],
                  nbit=qw,
                  is_quant=qw > 0):
        logits = (LinearWrap(image)
                  .Conv2DQuant('conv1', 2 * GROWTH_RATE, 7, stride=2, nl=BNReLU, is_quant=False)
                  .MaxPooling('pool1', shape=3, stride=2, padding='SAME')
                  # Spatial sizes below assume a 224x224 input (two stride-2
                  # ops before block0, one 2x pool per transition).
                  # 56
                  .apply(add_dense_block, 'block0', 6)
                  # 28
                  .apply(add_dense_block, 'block1', 12)
                  # 14
                  .apply(add_dense_block, 'block2', 24)
                  # 7
                  .apply(add_dense_block, 'block3', 16, last=True)
                  .BNReLU('bnrelu_last')
                  .GlobalAvgPooling('gap')
                  .FullyConnected('linear', out_dim=num_classes, nl=tf.identity,
                                  W_init=variance_scaling_initializer(mode='FAN_IN'))())
    return logits
|