79 строки
2.6 KiB
Python
79 строки
2.6 KiB
Python
"""MemNet"""
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch.autograd import Variable
|
|
|
|
|
|
class MemNet(nn.Module):
    """MemNet: feature extractor -> chain of memory blocks -> reconstructor.

    The input is mapped to ``channels`` feature maps by ``feature_extractor``
    (a BNReLUConv), refined by ``num_memblock`` MemoryBlocks in sequence, and
    mapped back to ``in_channels`` by ``reconstructor`` (a BNReLUConv). The
    reconstruction is added to the original input, so the network body
    learns a residual correction to ``x``.

    All memory blocks share one list of long-term memories. It starts with
    the extracted features, and each MemoryBlock appends its own output to it.

    Args:
        in_channels: channel count of the input (and therefore the output).
        channels: number of feature maps used inside the network.
        num_memblock: how many MemoryBlocks are chained.
        num_resblock: how many ResidualBlocks each MemoryBlock contains.
    """

    def __init__(self, in_channels, channels, num_memblock, num_resblock):
        super().__init__()
        # Channel count of the input/output tensor; not read inside this class.
        self.image_channels = in_channels
        self.feature_extractor = BNReLUConv(in_channels, channels)
        self.reconstructor = BNReLUConv(channels, in_channels)
        # Each block gets its 1-based position, which equals the number of
        # long-term memories it will see (extracted features + every earlier
        # block's output). Its gate unit is sized from that count.
        self.dense_memory = nn.ModuleList()
        for position in range(1, num_memblock + 1):
            self.dense_memory.append(MemoryBlock(channels, num_resblock, position))

    def forward(self, x):
        """Restore ``x`` and return ``reconstructor(features) + x``.

        Args:
            x: input batch. Presumably (N, in_channels, H, W), since the
                first layer is BatchNorm2d followed by Conv2d. TODO confirm
                against callers.

        Returns:
            Tensor with the same shape as ``x``, provided every conv
            preserves spatial size (true for the default k=3, s=1, p=1).
        """
        features = self.feature_extractor(x)
        long_term_memory = [features]
        for block in self.dense_memory:
            # The block appends its output to long_term_memory in place.
            features = block(features, long_term_memory)
        reconstructed = self.reconstructor(features)
        return reconstructed + x
|
|
|
|
|
|
class MemoryBlock(nn.Module):
    """One MemNet memory block: a recursive unit followed by a gate unit.

    The recursive unit is a chain of ``num_resblock`` ResidualBlocks; the
    output of each one is kept as short-term memory. The gate unit is a
    1x1 BNReLUConv that fuses the short-term memory with the long-term
    memory (the tensors passed in ``ys``) back down to ``channels`` maps.

    Args:
        channels: number of feature maps entering and leaving the block.
        num_resblock: number of ResidualBlocks in the recursive unit.
        num_memblock: 1-based position of this block in the network, i.e.
            the number of long-term memory tensors it will receive in
            ``ys``. Despite the name, it is NOT the total number of memory
            blocks in the network.
    """

    def __init__(self, channels, num_resblock, num_memblock):
        super(MemoryBlock, self).__init__()
        self.recursive_unit = nn.ModuleList(
            [ResidualBlock(channels) for _ in range(num_resblock)]
        )
        # The gate input is num_resblock short-term plus num_memblock
        # long-term memories, concatenated along dim 1. This assumes each
        # of them carries `channels` feature maps.
        self.gate_unit = BNReLUConv(
            (num_resblock + num_memblock) * channels, channels, 1, 1, 0
        )

    def forward(self, x, ys):
        """Run the recursive unit, then gate short- and long-term memory.

        Args:
            x: input feature map for this block.
            ys: list of long-term memory tensors coming from earlier in the
                network. MemNet passes the extracted features plus every
                previous block's output. It must hold exactly
                ``num_memblock`` tensors so the gate unit's input channel
                count matches. **Mutated in place**: this block's output is
                appended, so later blocks that receive the same list see it.

        Returns:
            The gate unit output, with ``channels`` feature maps.
        """
        xs = []  # short-term memory: output of every ResidualBlock
        for layer in self.recursive_unit:
            x = layer(x)
            xs.append(x)

        # Short-term memories come first, then long-term, along channels.
        gate_out = self.gate_unit(torch.cat(xs + ys, 1))
        ys.append(gate_out)
        return gate_out
|
|
|
|
|
|
class ResidualBlock(torch.nn.Module):
    """Residual block (introduced in: https://arxiv.org/abs/1512.03385).

    Computes ``body(x) + x``, where ``body`` is two stacked BNReLUConv
    layers:

        x -> BN -> ReLU -> Conv -> BN -> ReLU -> Conv -> (+ x)

    Args:
        channels: feature maps in and out; both convs keep this count.
        k: kernel size of both convolutions.
        s: stride of both convolutions.
        p: padding of both convolutions.

    The identity skip only lines up when the convs preserve spatial size.
    That holds for the defaults k=3, s=1, p=1.
    """

    def __init__(self, channels, k=3, s=1, p=1):
        super().__init__()
        self.relu_conv1 = BNReLUConv(channels, channels, k, s, p)
        self.relu_conv2 = BNReLUConv(channels, channels, k, s, p)

    def forward(self, x):
        """Return ``relu_conv2(relu_conv1(x)) + x``."""
        transformed = self.relu_conv2(self.relu_conv1(x))
        return transformed + x
|
|
|
|
|
|
class BNReLUConv(nn.Sequential):
    """Sequential unit: BatchNorm2d -> ReLU -> Conv2d (no conv bias).

    Child modules are registered as ``bn``, ``relu`` and ``conv``, in that
    order.

    Args:
        in_channels: channels of the input; BatchNorm2d and Conv2d input.
        channels: output channels of the convolution.
        k: convolution kernel size.
        s: convolution stride.
        p: convolution zero-padding.
        inplace: passed to nn.ReLU. The ReLU only ever overwrites the
            BatchNorm output, never the tensor handed to this module.
    """

    def __init__(self, in_channels, channels, k=3, s=1, p=1, inplace=True):
        super().__init__()
        stages = (
            ('bn', nn.BatchNorm2d(in_channels)),
            ('relu', nn.ReLU(inplace=inplace)),
            ('conv', nn.Conv2d(in_channels, channels, k, s, p, bias=False)),
        )
        for stage_name, stage in stages:
            self.add_module(stage_name, stage)