зеркало из https://github.com/microsoft/nni.git
rm useless file (#5042)
This commit is contained in:
Родитель
a0fd003671
Коммит
60e1e01f62
|
@ -1,7 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .layer_choice_only import LayerChoiceOnlySearchSpace
|
||||
from .mutable_scope import SpaceWithMutableScope
|
||||
from .naive import NaiveSearchSpace
|
||||
from .nested import NestedSpace
|
||||
|
|
|
@ -1,38 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from nni.nas.pytorch.mutables import LayerChoice
|
||||
|
||||
|
||||
class LayerChoiceOnlySearchSpace(nn.Module):
|
||||
def __init__(self, test_case):
|
||||
super().__init__()
|
||||
self.test_case = test_case
|
||||
self.conv1 = LayerChoice([nn.Conv2d(3, 6, 3, padding=1), nn.Conv2d(3, 6, 5, padding=2)])
|
||||
self.pool = nn.MaxPool2d(2, 2)
|
||||
self.conv2 = LayerChoice([nn.Conv2d(6, 16, 3, padding=1), nn.Conv2d(6, 16, 5, padding=2)],
|
||||
return_mask=True)
|
||||
self.conv3 = nn.Conv2d(16, 16, 1)
|
||||
self.bn = nn.BatchNorm2d(16)
|
||||
|
||||
self.gap = nn.AdaptiveAvgPool2d(1)
|
||||
self.fc = nn.Linear(16, 10)
|
||||
|
||||
def forward(self, x):
|
||||
bs = x.size(0)
|
||||
|
||||
x = self.pool(F.relu(self.conv1(x)))
|
||||
x0, mask = self.conv2(x)
|
||||
self.test_case.assertEqual(mask.size(), torch.Size([2]))
|
||||
x1 = F.relu(self.conv3(x0))
|
||||
|
||||
x = self.pool(self.bn(x1))
|
||||
self.test_case.assertEqual(mask.size(), torch.Size([2]))
|
||||
|
||||
x = self.gap(x).view(bs, -1)
|
||||
x = self.fc(x)
|
||||
return x
|
Загрузка…
Ссылка в новой задаче