зеркало из https://github.com/microsoft/archai.git
nasbench101 model creation
This commit is contained in:
Родитель
7f22aab21d
Коммит
faaa719cba
|
@ -16,7 +16,7 @@ from __future__ import print_function
|
|||
import numpy as np
|
||||
import math
|
||||
|
||||
from base_ops import *
|
||||
from .base_ops import *
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -24,20 +24,20 @@ import torch.nn.functional as F
|
|||
|
||||
|
||||
class Network(nn.Module):
|
||||
def __init__(self, spec, args):
|
||||
def __init__(self, spec, stem_out_channels, num_stacks, num_modules_per_stack, num_labels):
|
||||
super(Network, self).__init__()
|
||||
|
||||
self.layers = nn.ModuleList([])
|
||||
|
||||
in_channels = 3
|
||||
out_channels = args.stem_out_channels # out channels for the model stem
|
||||
out_channels = stem_out_channels # out channels for the model stem
|
||||
|
||||
# initial stem convolution
|
||||
stem_conv = ConvBnRelu(in_channels, out_channels, 3, 1, 1)
|
||||
self.layers.append(stem_conv)
|
||||
|
||||
in_channels = out_channels
|
||||
for stack_num in range(args.num_stacks):
|
||||
for stack_num in range(num_stacks):
|
||||
if stack_num > 0:
|
||||
# downsampling by maxpool doesn't change the channel
|
||||
downsample = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
|
@ -45,12 +45,12 @@ class Network(nn.Module):
|
|||
|
||||
out_channels *= 2
|
||||
|
||||
for module_num in range(args.num_modules_per_stack):
|
||||
for module_num in range(num_modules_per_stack):
|
||||
cell = Cell(spec, in_channels, out_channels)
|
||||
self.layers.append(cell)
|
||||
in_channels = out_channels
|
||||
|
||||
self.classifier = nn.Linear(out_channels, args.num_labels)
|
||||
self.classifier = nn.Linear(out_channels, num_labels)
|
||||
|
||||
self._initialize_weights()
|
||||
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
from typing import List
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .model import Network
|
||||
from .model_spec import ModelSpec
|
||||
|
||||
|
||||
# Operation label for each vertex of the example cell, listed in vertex order.
# The first entry is 'input' and the last entry is 'output'.
VERTEX_OPS = [
    'input',
    'conv1x1-bn-relu',
    'conv3x3-bn-relu',
    'conv3x3-bn-relu',
    'conv3x3-bn-relu',
    'maxpool3x3',
    'output',
]

# Example 7x7 matrix of 0/1 entries: one row and one column per VERTEX_OPS entry.
# NOTE(review): presumably an adjacency matrix where a 1 at [i][j] is an edge
# from vertex i to vertex j -- confirm against ModelSpec.
EXAMPLE_DESC_MATRIX = [
    [0, 1, 1, 1, 0, 1, 0],
    [0, 0, 0, 0, 0, 0, 1],
    [0, 0, 0, 0, 0, 0, 1],
    [0, 0, 0, 0, 1, 0, 0],
    [0, 0, 0, 0, 0, 0, 1],
    [0, 0, 0, 0, 0, 0, 1],
    [0, 0, 0, 0, 0, 0, 0],
]
|
||||
|
||||
def build(desc_matrix: List[List[int]], vertex_ops=VERTEX_OPS, device=None,
          stem_out_channels=128, num_stacks=3, num_modules_per_stack=3,
          num_labels=10) -> nn.Module:
    """Create a ``Network`` from a cell description and move it to a device.

    Args:
        desc_matrix: Nested list of ints passed to ``ModelSpec`` together
            with ``vertex_ops``.
        vertex_ops: Per-vertex operation labels; defaults to ``VERTEX_OPS``.
        device: Target device for the model. When ``None``, CUDA is used if
            ``torch.cuda.is_available()`` reports it, otherwise the CPU.
        stem_out_channels: Forwarded to ``Network``.
        num_stacks: Forwarded to ``Network``.
        num_modules_per_stack: Forwarded to ``Network``.
        num_labels: Forwarded to ``Network``.

    Returns:
        The constructed network after ``.to(device)`` has been applied.
    """
    spec = ModelSpec(desc_matrix, vertex_ops)
    network = Network(spec, stem_out_channels, num_stacks,
                      num_modules_per_stack, num_labels)

    # Pick a device only when the caller did not choose one.
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    network.to(device)
    return network
|
|
@ -11,7 +11,7 @@ from __future__ import print_function
|
|||
import copy
|
||||
import numpy as np
|
||||
|
||||
import graph_util
|
||||
from . import graph_util
|
||||
|
||||
# Graphviz is optional and only required for visualization.
|
||||
try:
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
__include__: "darts.yaml" # just use darts defaults
|
||||
|
||||
nas:
|
||||
search:
|
||||
eval:
|
||||
model_desc:
|
||||
params: {
|
||||
'cell_matrix' : [[0, 1, 1, 1, 0, 1, 0],
|
||||
|
@ -19,4 +19,22 @@ nas:
|
|||
stem_multiplier: 1 # output channels for stem = 128
|
||||
init_node_ch: 128 # num of input/output channels for nodes in 1st cell
|
||||
model_post_op: 'pool_mean_tensor'
|
||||
n_cells: 9 # 3 stacks, each stack with 3 cells
|
||||
n_cells: 9 # 3 stacks, each stack with 3 cells
|
||||
loader:
|
||||
aug: '' # additional augmentations to use, for ex, fa_reduced_cifar10, arsaug, autoaug_cifar10, autoaug_extend
|
||||
cutout: 0 # cutout length, use cutout augmentation when > 0
|
||||
train_batch: 128 # 96 is too aggressive for 1080Ti, better set it to 68
|
||||
trainer:
|
||||
aux_weight: 0.0
|
||||
drop_path_prob: 0.0 # probability that given edge will be dropped
|
||||
grad_clip: 5.0 # grads above this value are clipped
|
||||
epochs: 100
|
||||
optimizer:
|
||||
type: 'sgd'
|
||||
lr: 0.025 # init learning rate
|
||||
decay: 1.0e-4 # pytorch default is 0.0
|
||||
momentum: 0.9 # pytorch default is 0.0
|
||||
nesterov: False # pytorch default is False
|
||||
lr_schedule:
|
||||
type: 'cosine'
|
||||
min_lr: 0.0 # min learning rate to be set in eta_min param of scheduler
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
from archai.algos.nasbench101 import model_builder
|
||||
from archai import cifar10_models
|
||||
from archai.common.trainer import Trainer
|
||||
from archai.common.config import Config
|
||||
from archai.common.common import common_init
|
||||
from archai.datasets import data
|
||||
|
||||
|
||||
def main():
    """Build the example NASBench-101 model and fit it with a ``Trainer``.

    Loader and trainer settings are read from the ``nas.eval`` section of
    the config returned by ``common_init``.
    """
    # NOTE(review): this loads the resnet config although the model built
    # below comes from the nasbench101 model_builder -- confirm this is the
    # intended config file.
    conf = common_init(config_filepath='confs/algos/resnet.yaml')
    eval_section = conf['nas']['eval']
    loader_settings = eval_section['loader']
    trainer_settings = eval_section['trainer']

    network = model_builder.build(model_builder.EXAMPLE_DESC_MATRIX)

    # get_data returns three items; the middle one is not used here.
    train_dl, _, test_dl = data.get_data(loader_settings)

    Trainer(trainer_settings, network).fit(train_dl, test_dl)


if __name__ == '__main__':
    main()
|
Загрузка…
Ссылка в новой задаче