fix torch se block + channel shuffle block

This commit is contained in:
jiahangxu 2022-03-28 10:31:33 -04:00
Родитель 186d7abb99
Коммит 1e7e0a35b3
2 изменённых файлов: 7 добавлений и 5 удалений

Просмотреть файл

@ -13,7 +13,7 @@ class TorchBlock(BaseBlock):
def test_block(self):
input_data = torch.randn(1, self.config["CIN"], self.config["HW"], self.config["HW"])
output_data = self.get_model()(input_data)
print("output size: ", output_data.shape)
print("input size:", input_data.shape, "output size: ", output_data.shape)
def save_model(self, save_path):
model = self.get_model()
@ -629,10 +629,12 @@ class ChannelShuffle(TorchBlock):
class Model(nn.Module):
def __init__(self):
super().__init__()
self.channel_shuffle = nn.ChannelShuffle(2)
def forward(self, inputs):
x = self.channel_shuffle(inputs)
_, c, h, w = list(inputs.shape)
x = torch.reshape(inputs, [-1, c // 2, 2, h, w])
x = torch.transpose(x, 2, 1)
x = torch.reshape(x, [-1, c, h, w])
return x
return Model()

Просмотреть файл

@ -108,8 +108,8 @@ class SE(BaseOperator):
super().__init__()
cin = input_shape[0]
self.avgpool = nn.AdaptiveAvgPool2d([1, 1])
self.conv1 = nn.Conv2d(cin, cin // 4, kernel_size=1, stride=1, padding='same')
self.conv2 = nn.Conv2d(cin // 4, cin, kernel_size=1, stride=1, padding='same')
self.conv1 = nn.Conv2d(cin, cin // 4, kernel_size=1, stride=1, padding=0)
self.conv2 = nn.Conv2d(cin // 4, cin, kernel_size=1, stride=1, padding=0)
self.relu = nn.ReLU()
self.hswish = nn.Hardswish()