fix torch se block + channel shuffle block
Parent: 186d7abb99
Commit: 1e7e0a35b3
@@ -13,7 +13,7 @@ class TorchBlock(BaseBlock):
     def test_block(self):
         input_data = torch.randn(1, self.config["CIN"], self.config["HW"], self.config["HW"])
         output_data = self.get_model()(input_data)
-        print("output size: ", output_data.shape)
+        print("input size:", input_data.shape, "output size: ", output_data.shape)
 
     def save_model(self, save_path):
         model = self.get_model()
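With this change the block self-test reports both tensor shapes. A minimal standalone sketch of the behaviour, using a hypothetical block and made-up config values for CIN and HW (neither the class nor the values come from this commit):

import torch
import torch.nn as nn

class DummyBlock:
    # Hypothetical stand-in for a TorchBlock subclass; config values are examples only.
    def __init__(self):
        self.config = {"CIN": 32, "HW": 28}

    def get_model(self):
        return nn.Conv2d(self.config["CIN"], 64, kernel_size=3, padding=1)

    def test_block(self):
        input_data = torch.randn(1, self.config["CIN"], self.config["HW"], self.config["HW"])
        output_data = self.get_model()(input_data)
        # The updated print reports the input shape alongside the output shape.
        print("input size:", input_data.shape, "output size: ", output_data.shape)

DummyBlock().test_block()
# input size: torch.Size([1, 32, 28, 28]) output size:  torch.Size([1, 64, 28, 28])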
@@ -629,10 +629,12 @@ class ChannelShuffle(TorchBlock):
         class Model(nn.Module):
             def __init__(self):
                 super().__init__()
-                self.channel_shuffle = nn.ChannelShuffle(2)
 
             def forward(self, inputs):
-                x = self.channel_shuffle(inputs)
+                _, c, h, w = list(inputs.shape)
+                x = torch.reshape(inputs, [-1, c // 2, 2, h, w])
+                x = torch.transpose(x, 2, 1)
+                x = torch.reshape(x, [-1, c, h, w])
                 return x
 
         return Model()
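The forward pass now performs the shuffle with explicit reshape/transpose ops instead of the nn.ChannelShuffle module. The commit message only says "fix"; avoiding the dedicated module is a common workaround when export or conversion paths do not support it, but that motivation is an assumption here. Note also that grouping the channels as (c // 2, 2) produces the inverse of the permutation nn.ChannelShuffle(2) would apply (for channel counts other than 4); both are channel shuffles in the ShuffleNet sense. A small sketch that makes the new permutation visible, with an illustrative helper name and toy tensor:

import torch

def shuffle_channels(inputs):
    # Same sequence of ops as the new ChannelShuffle forward body.
    _, c, h, w = list(inputs.shape)
    x = torch.reshape(inputs, [-1, c // 2, 2, h, w])
    x = torch.transpose(x, 2, 1)
    x = torch.reshape(x, [-1, c, h, w])
    return x

# One distinct value per channel so the reordering is easy to read off.
x = torch.arange(6, dtype=torch.float32).reshape(1, 6, 1, 1)
y = shuffle_channels(x)
print(x.flatten().tolist())  # [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
print(y.flatten().tolist())  # [0.0, 2.0, 4.0, 1.0, 3.0, 5.0]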
@@ -108,8 +108,8 @@ class SE(BaseOperator):
                 super().__init__()
                 cin = input_shape[0]
                 self.avgpool = nn.AdaptiveAvgPool2d([1, 1])
-                self.conv1 = nn.Conv2d(cin, cin // 4, kernel_size=1, stride=1, padding='same')
-                self.conv2 = nn.Conv2d(cin // 4, cin, kernel_size=1, stride=1, padding='same')
+                self.conv1 = nn.Conv2d(cin, cin // 4, kernel_size=1, stride=1, padding=0)
+                self.conv2 = nn.Conv2d(cin // 4, cin, kernel_size=1, stride=1, padding=0)
                 self.relu = nn.ReLU()
                 self.hswish = nn.Hardswish()
 
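For a 1x1 convolution with stride 1, padding='same' adds no padding at all, so switching to padding=0 leaves the SE block's output unchanged while avoiding string-valued padding, which only newer PyTorch releases accept. Below is a self-contained sketch that wires the layers from this hunk into a squeeze-and-excitation module; the forward pass is not part of the diff, so the gating shown here (hard-swish on the expanded branch, then channel-wise scaling) is an assumed, typical arrangement:

import torch
import torch.nn as nn

class SEBlock(nn.Module):
    # Illustrative SE block assembled from the layers in the hunk above.
    def __init__(self, cin):
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d([1, 1])
        # 1x1 convs with stride 1: padding=0 is numerically identical to padding='same'.
        self.conv1 = nn.Conv2d(cin, cin // 4, kernel_size=1, stride=1, padding=0)
        self.conv2 = nn.Conv2d(cin // 4, cin, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.hswish = nn.Hardswish()

    def forward(self, x):
        # Assumed gating: squeeze -> reduce -> expand -> hard-swish -> scale.
        scale = self.avgpool(x)
        scale = self.relu(self.conv1(scale))
        scale = self.hswish(self.conv2(scale))
        return x * scale

x = torch.randn(1, 16, 28, 28)
print(SEBlock(16)(x).shape)  # torch.Size([1, 16, 28, 28])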