Add files via upload
This commit is contained in:
Родитель
2bd1ed9552
Коммит
432d95036b
|
@ -26,34 +26,6 @@ class RGA_Module(nn.Module):
|
|||
|
||||
self.inter_channel = in_channel // cha_ratio
|
||||
self.inter_spatial = in_spatial // spa_ratio
|
||||
|
||||
# Embedding functions for modeling relations
|
||||
if self.use_spatial:
|
||||
self.theta_spatial = nn.Sequential(
|
||||
nn.Conv2d(in_channels=self.in_channel, out_channels=self.inter_channel,
|
||||
kernel_size=1, stride=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(self.inter_channel),
|
||||
nn.ReLU()
|
||||
)
|
||||
self.phi_spatial = nn.Sequential(
|
||||
nn.Conv2d(in_channels=self.in_channel, out_channels=self.inter_channel,
|
||||
kernel_size=1, stride=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(self.inter_channel),
|
||||
nn.ReLU()
|
||||
)
|
||||
if self.use_channel:
|
||||
self.theta_channel = nn.Sequential(
|
||||
nn.Conv2d(in_channels=self.in_spatial, out_channels=self.inter_spatial,
|
||||
kernel_size=1, stride=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(self.inter_spatial),
|
||||
nn.ReLU()
|
||||
)
|
||||
self.phi_channel = nn.Sequential(
|
||||
nn.Conv2d(in_channels=self.in_spatial, out_channels=self.inter_spatial,
|
||||
kernel_size=1, stride=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(self.inter_spatial),
|
||||
nn.ReLU()
|
||||
)
|
||||
|
||||
# Embedding functions for original features
|
||||
if self.use_spatial:
|
||||
|
@ -110,6 +82,34 @@ class RGA_Module(nn.Module):
|
|||
kernel_size=1, stride=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(1)
|
||||
)
|
||||
|
||||
# Embedding functions for modeling relations
|
||||
if self.use_spatial:
|
||||
self.theta_spatial = nn.Sequential(
|
||||
nn.Conv2d(in_channels=self.in_channel, out_channels=self.inter_channel,
|
||||
kernel_size=1, stride=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(self.inter_channel),
|
||||
nn.ReLU()
|
||||
)
|
||||
self.phi_spatial = nn.Sequential(
|
||||
nn.Conv2d(in_channels=self.in_channel, out_channels=self.inter_channel,
|
||||
kernel_size=1, stride=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(self.inter_channel),
|
||||
nn.ReLU()
|
||||
)
|
||||
if self.use_channel:
|
||||
self.theta_channel = nn.Sequential(
|
||||
nn.Conv2d(in_channels=self.in_spatial, out_channels=self.inter_spatial,
|
||||
kernel_size=1, stride=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(self.inter_spatial),
|
||||
nn.ReLU()
|
||||
)
|
||||
self.phi_channel = nn.Sequential(
|
||||
nn.Conv2d(in_channels=self.in_spatial, out_channels=self.inter_spatial,
|
||||
kernel_size=1, stride=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(self.inter_spatial),
|
||||
nn.ReLU()
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.size()
|
||||
|
@ -156,5 +156,4 @@ class RGA_Module(nn.Module):
|
|||
W_yc = self.W_channel(yc).transpose(1, 2)
|
||||
out = F.sigmoid(W_yc) * x
|
||||
|
||||
return out
|
||||
|
||||
return out
|
Загрузка…
Ссылка в новой задаче