This commit is contained in:
Cuiling Lan 2020-04-15 04:41:49 +08:00 коммит произвёл GitHub
Родитель 2bd1ed9552
Коммит 432d95036b
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 29 добавлений и 30 удалений

Просмотреть файл

@ -26,34 +26,6 @@ class RGA_Module(nn.Module):
self.inter_channel = in_channel // cha_ratio
self.inter_spatial = in_spatial // spa_ratio
# Embedding functions for modeling relations
if self.use_spatial:
self.theta_spatial = nn.Sequential(
nn.Conv2d(in_channels=self.in_channel, out_channels=self.inter_channel,
kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(self.inter_channel),
nn.ReLU()
)
self.phi_spatial = nn.Sequential(
nn.Conv2d(in_channels=self.in_channel, out_channels=self.inter_channel,
kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(self.inter_channel),
nn.ReLU()
)
if self.use_channel:
self.theta_channel = nn.Sequential(
nn.Conv2d(in_channels=self.in_spatial, out_channels=self.inter_spatial,
kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(self.inter_spatial),
nn.ReLU()
)
self.phi_channel = nn.Sequential(
nn.Conv2d(in_channels=self.in_spatial, out_channels=self.inter_spatial,
kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(self.inter_spatial),
nn.ReLU()
)
# Embedding functions for original features
if self.use_spatial:
@ -110,6 +82,34 @@ class RGA_Module(nn.Module):
kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(1)
)
# Embedding functions for modeling relations
if self.use_spatial:
self.theta_spatial = nn.Sequential(
nn.Conv2d(in_channels=self.in_channel, out_channels=self.inter_channel,
kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(self.inter_channel),
nn.ReLU()
)
self.phi_spatial = nn.Sequential(
nn.Conv2d(in_channels=self.in_channel, out_channels=self.inter_channel,
kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(self.inter_channel),
nn.ReLU()
)
if self.use_channel:
self.theta_channel = nn.Sequential(
nn.Conv2d(in_channels=self.in_spatial, out_channels=self.inter_spatial,
kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(self.inter_spatial),
nn.ReLU()
)
self.phi_channel = nn.Sequential(
nn.Conv2d(in_channels=self.in_spatial, out_channels=self.inter_spatial,
kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(self.inter_spatial),
nn.ReLU()
)
def forward(self, x):
b, c, h, w = x.size()
@ -156,5 +156,4 @@ class RGA_Module(nn.Module):
W_yc = self.W_channel(yc).transpose(1, 2)
out = F.sigmoid(W_yc) * x
return out
return out