From faaa719cba1c35560382fc43fca472658c6e6771 Mon Sep 17 00:00:00 2001 From: Shital Shah Date: Wed, 13 Jan 2021 11:06:47 -0800 Subject: [PATCH] nasbench101 model creation --- archai/algos/nasbench101/__init__.py | 0 archai/algos/nasbench101/model.py | 12 ++++----- archai/algos/nasbench101/model_builder.py | 26 +++++++++++++++++++ archai/algos/nasbench101/model_spec.py | 2 +- confs/algos/nasbench101.yaml | 22 ++++++++++++++-- ...101_demo.py => nasbench101_archai_demo.py} | 0 scripts/misc/nasbench101_pytorch_demo.py | 23 ++++++++++++++++ 7 files changed, 76 insertions(+), 9 deletions(-) create mode 100644 archai/algos/nasbench101/__init__.py rename scripts/misc/{nasbench101_demo.py => nasbench101_archai_demo.py} (100%) create mode 100644 scripts/misc/nasbench101_pytorch_demo.py diff --git a/archai/algos/nasbench101/__init__.py b/archai/algos/nasbench101/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/archai/algos/nasbench101/model.py b/archai/algos/nasbench101/model.py index 74887dac..8b8a0827 100644 --- a/archai/algos/nasbench101/model.py +++ b/archai/algos/nasbench101/model.py @@ -16,7 +16,7 @@ from __future__ import print_function import numpy as np import math -from base_ops import * +from .base_ops import * import torch import torch.nn as nn @@ -24,20 +24,20 @@ import torch.nn.functional as F class Network(nn.Module): - def __init__(self, spec, args): + def __init__(self, spec, stem_out_channels, num_stacks, num_modules_per_stack, num_labels): super(Network, self).__init__() self.layers = nn.ModuleList([]) in_channels = 3 - out_channels = args.stem_out_channels # out channels for the model stem + out_channels = stem_out_channels # out channels for the model stem # initial stem convolution stem_conv = ConvBnRelu(in_channels, out_channels, 3, 1, 1) self.layers.append(stem_conv) in_channels = out_channels - for stack_num in range(args.num_stacks): + for stack_num in range(num_stacks): if stack_num > 0: # downsampling by maxpool doesn't change the channel downsample = nn.MaxPool2d(kernel_size=2, stride=2) @@ -45,12 +45,12 @@ class Network(nn.Module): out_channels *= 2 - for module_num in range(args.num_modules_per_stack): + for module_num in range(num_modules_per_stack): cell = Cell(spec, in_channels, out_channels) self.layers.append(cell) in_channels = out_channels - self.classifier = nn.Linear(out_channels, args.num_labels) + self.classifier = nn.Linear(out_channels, num_labels) self._initialize_weights() diff --git a/archai/algos/nasbench101/model_builder.py b/archai/algos/nasbench101/model_builder.py index e69de29b..d879aa1d 100644 --- a/archai/algos/nasbench101/model_builder.py +++ b/archai/algos/nasbench101/model_builder.py @@ -0,0 +1,26 @@ +from typing import List + +import torch +from torch import nn + +from .model import Network +from .model_spec import ModelSpec + + +VERTEX_OPS = ['input', 'conv1x1-bn-relu', 'conv3x3-bn-relu', 'conv3x3-bn-relu', 'conv3x3-bn-relu', 'maxpool3x3', 'output'] + +EXAMPLE_DESC_MATRIX = [[0, 1, 1, 1, 0, 1, 0], + [0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 0, 0]] + +def build(desc_matrix:List[List[int]], vertex_ops=VERTEX_OPS, device=None, + stem_out_channels=128, num_stacks=3, num_modules_per_stack=3, num_labels=10)->nn.Module: + model_spec = ModelSpec(desc_matrix, vertex_ops) + model = Network(model_spec, stem_out_channels, num_stacks, num_modules_per_stack, num_labels) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device + model.to(device) + return model \ No newline at end of file diff --git a/archai/algos/nasbench101/model_spec.py b/archai/algos/nasbench101/model_spec.py index 4f43b259..f5bfe4aa 100644 --- a/archai/algos/nasbench101/model_spec.py +++ b/archai/algos/nasbench101/model_spec.py @@ -11,7 +11,7 @@ from __future__ import print_function import copy import numpy as np -import graph_util +from . import graph_util # Graphviz is optional and only required for visualization. try: diff --git a/confs/algos/nasbench101.yaml b/confs/algos/nasbench101.yaml index f22e8bb6..5de022ab 100644 --- a/confs/algos/nasbench101.yaml +++ b/confs/algos/nasbench101.yaml @@ -1,7 +1,7 @@ __include__: "darts.yaml" # just use darts defaults nas: - search: + eval: model_desc: params: { 'cell_matrix' : [[0, 1, 1, 1, 0, 1, 0], @@ -19,4 +19,22 @@ nas: stem_multiplier: 1 # output channels for stem = 128 init_node_ch: 128 # num of input/output channels for nodes in 1st cell model_post_op: 'pool_mean_tensor' - n_cells: 9 # 3 stacks, each stack with 3 cells \ No newline at end of file + n_cells: 9 # 3 stacks, each stack with 3 cells + loader: + aug: '' # additional augmentations to use, for ex, fa_reduced_cifar10, arsaug, autoaug_cifar10, autoaug_extend + cutout: 0 # cutout length, use cutout augmentation when > 0 + train_batch: 128 # 96 is too aggressive for 1080Ti, better set it to 68 + trainer: + aux_weight: 0.0 + drop_path_prob: 0.0 # probability that given edge will be dropped + grad_clip: 5.0 # grads above this value is clipped + epochs: 100 + optimizer: + type: 'sgd' + lr: 0.025 # init learning rate + decay: 1.0e-4 # pytorch default is 0.0 + momentum: 0.9 # pytorch default is 0.0 + nesterov: False # pytorch default is False + lr_schedule: + type: 'cosine' + min_lr: 0.0 # min learning rate to se bet in eta_min param of scheduler diff --git a/scripts/misc/nasbench101_demo.py b/scripts/misc/nasbench101_archai_demo.py similarity index 100% rename from scripts/misc/nasbench101_demo.py rename to scripts/misc/nasbench101_archai_demo.py diff --git a/scripts/misc/nasbench101_pytorch_demo.py b/scripts/misc/nasbench101_pytorch_demo.py new file mode 100644 index 00000000..775895ff --- /dev/null +++ b/scripts/misc/nasbench101_pytorch_demo.py @@ -0,0 +1,23 @@ +from archai.algos.nasbench101 import model_builder +from archai import cifar10_models +from archai.common.trainer import Trainer +from archai.common.config import Config +from archai.common.common import common_init +from archai.datasets import data + + +def main(): + conf = common_init(config_filepath='confs/algos/resnet.yaml') + conf_eval = conf['nas']['eval'] + conf_loader = conf_eval['loader'] + conf_trainer = conf_eval['trainer'] + + model = model_builder.build(model_builder.EXAMPLE_DESC_MATRIX) + + train_dl, _, test_dl = data.get_data(conf_loader) + + trainer = Trainer(conf_trainer, model) + trainer.fit(train_dl, test_dl) + +if __name__ == '__main__': + main() \ No newline at end of file