зеркало из https://github.com/microsoft/archai.git
Refactored hog + small neural network classification pipeline.
This commit is contained in:
Родитель
e5ec907259
Коммит
8009911fb4
|
@ -418,6 +418,14 @@
|
|||
"--aug", "fa_reduced_cifar10"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Train Visual Features FF Net",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "${cwd}/scripts/misc/train_visual_features_ffnet.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": []
|
||||
},
|
||||
{
|
||||
"name": "Analysis Aggregate",
|
||||
"type": "python",
|
||||
|
|
|
@ -1,55 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Optional
|
||||
import importlib
|
||||
import sys
|
||||
import string
|
||||
import os
|
||||
|
||||
from overrides import overrides
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from overrides import overrides, EnforceOverrides
|
||||
|
||||
from archai.common.trainer import Trainer
|
||||
from archai.common.config import Config
|
||||
from archai.common.common import logger
|
||||
from archai.datasets import data
|
||||
from archai.nas.model_desc import ModelDesc
|
||||
from archai.nas.model_desc_builder import ModelDescBuilder
|
||||
from archai.nas import nas_utils
|
||||
from archai.common import ml_utils, utils
|
||||
from archai.common.metrics import EpochMetrics, Metrics
|
||||
from archai.nas.model import Model
|
||||
from archai.common.checkpoint import CheckPoint
|
||||
from archai.nas.evaluater import Evaluater
|
||||
from archai.algos.proxynas.freeze_trainer import FreezeTrainer
|
||||
from archai.algos.proxynas.conditional_trainer import ConditionalTrainer
|
||||
|
||||
from nats_bench import create
|
||||
from archai.algos.natsbench.lib.models import get_cell_based_tiny_net
|
||||
|
||||
class LinearEvaluater(Evaluater):
|
||||
@overrides
|
||||
def create_model(self, conf_eval:Config, model_desc_builder:ModelDescBuilder,
|
||||
final_desc_filename=None, full_desc_filename=None)->nn.Module:
|
||||
# region conf vars
|
||||
dataset_name = conf_eval['loader']['dataset']['name']
|
||||
|
||||
# if explicitly passed in then don't get from conf
|
||||
if not final_desc_filename:
|
||||
final_desc_filename = conf_eval['final_desc_filename']
|
||||
|
||||
dataroot = utils.full_path(conf_eval['loader']['dataset']['dataroot'])
|
||||
# endregion
|
||||
|
||||
# create linear model
|
||||
|
||||
# problem: what is the input size of the image features?
|
||||
|
||||
|
||||
return self._model_from_natsbench(arch_index, dataset_name, natsbench_location)
|
|
@ -1,17 +1,22 @@
|
|||
import torch
|
||||
from torch.functional import Tensor
|
||||
import torch.nn as nn
|
||||
|
||||
from skimage import feature
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
|
||||
class VisualFeaturesWithLinearNet(nn.Module):
|
||||
class VisualFeaturesWithFFNet(nn.Module):
|
||||
def __init__(self, feature_len:int, n_classes:int):
|
||||
super(VisualFeaturesWithLinearNet, self).__init__()
|
||||
super(VisualFeaturesWithFFNet, self).__init__()
|
||||
self.feature_len = feature_len
|
||||
|
||||
self.fc = nn.Linear(feature_len, n_classes)
|
||||
self.net = nn.Sequential(nn.Linear(feature_len, 2048),
|
||||
nn.ReLU(),
|
||||
nn.Linear(2048, 2048),
|
||||
nn.ReLU(),
|
||||
nn.Linear(2048, n_classes))
|
||||
|
||||
#self.fc = nn.Linear(feature_len, n_classes)
|
||||
|
||||
|
||||
def _compute_features(self, x:Tensor)->Tensor:
|
||||
|
@ -41,4 +46,4 @@ class VisualFeaturesWithLinearNet(nn.Module):
|
|||
|
||||
def forward(self, x):
|
||||
feats = self._compute_features(x)
|
||||
return self.fc(feats.float())
|
||||
return self.net(feats.float())
|
|
@ -1,5 +1,12 @@
|
|||
__include__: 'darts.yaml' # just use darts defaults
|
||||
|
||||
dataset:
|
||||
name: 'ImageNet16-120'
|
||||
n_classes: 120
|
||||
channels: 3 # number of channels in image
|
||||
max_batches: -1 # if >= 0 then only these many batches are generated (useful for debugging)
|
||||
storage_name: 'imagenet16' # name of folder or tar file to copy from cloud storage
|
||||
|
||||
common:
|
||||
experiment_name: 'VisualFeaturesWithLinearNet'
|
||||
|
|
@ -2,7 +2,7 @@
|
|||
# Licensed under the MIT license.
|
||||
|
||||
import torch
|
||||
from archai.networks.visual_features_with_linear_net import VisualFeaturesWithLinearNet
|
||||
from archai.networks.visual_features_with_ff_net import VisualFeaturesWithFFNet
|
||||
from archai.common.trainer import Trainer
|
||||
from archai.common.config import Config
|
||||
from archai.common.common import common_init
|
||||
|
@ -13,7 +13,7 @@ def train_test(conf_eval:Config):
|
|||
conf_trainer = conf_eval['trainer']
|
||||
|
||||
# create model
|
||||
Net = VisualFeaturesWithLinearNet
|
||||
Net = VisualFeaturesWithFFNet
|
||||
feature_len = 324
|
||||
n_classes = 10
|
||||
model = Net(feature_len, n_classes).to(torch.device('cuda', 0))
|
||||
|
@ -27,7 +27,7 @@ def train_test(conf_eval:Config):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
conf = common_init(config_filepath='confs/algos/visual_features_linearnet.yaml')
|
||||
conf = common_init(config_filepath='confs/algos/visual_features_ffnet.yaml')
|
||||
conf_eval = conf['nas']['eval']
|
||||
|
||||
train_test(conf_eval)
|
|
@ -32,6 +32,8 @@ natsbench_cifar10:
|
|||
'ft_fb256_ftlr0.1_fte5_ct256_ftt0.6_c9',
|
||||
'ft_fb256_ftlr0.1_fte10_ct256_ftt0.6_c9']
|
||||
|
||||
zero_cost: ['zero_cost_nb201_cifar10']
|
||||
|
||||
shortreg: ['nb_reg_b1024_e01',
|
||||
'nb_reg_b1024_e02',
|
||||
'nb_reg_b1024_e04',
|
||||
|
@ -67,6 +69,8 @@ natsbench_cifar100:
|
|||
'nb_c100_reg_b2048_e20',
|
||||
'nb_c100_reg_b2048_e30']
|
||||
|
||||
zero_cost: ['zero_cost_nb201_cifar100']
|
||||
|
||||
natsbench_imagenet16-120:
|
||||
freezetrain: [ft_i6_fb2048_ftlr0.1_fte5_ct256_ftt0.1,
|
||||
ft_i6_fb2048_ftlr0.1_fte10_ct256_ftt0.1,
|
||||
|
@ -84,6 +88,8 @@ natsbench_imagenet16-120:
|
|||
ft_i6_fb512_ftlr0.1_fte10_ct256_ftt0.2,
|
||||
ft_i6_fb256_ftlr0.1_fte5_ct256_ftt0.2]
|
||||
|
||||
zero_cost: ['zero_cost_nb201_ImageNet16-120']
|
||||
|
||||
shortreg: [nb_i16_reg_b256_e10,
|
||||
nb_i16_reg_b256_e20,
|
||||
nb_i16_reg_b256_e30,
|
||||
|
|
|
@ -14,6 +14,8 @@ import plotly.express as px
|
|||
from plotly.subplots import make_subplots
|
||||
import plotly.graph_objects as go
|
||||
|
||||
ZERO_COST_MEASURES = ['fisher', 'grad_norm', 'grasp', 'jacob_cov', 'plain', 'snip', 'synflow', 'synflow_bn']
|
||||
|
||||
|
||||
def parse_raw_data(root_exp_folder:str, exp_list:List[str])->Dict:
|
||||
data = {}
|
||||
|
@ -26,7 +28,7 @@ def parse_raw_data(root_exp_folder:str, exp_list:List[str])->Dict:
|
|||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Cross Experiment Plots')
|
||||
parser.add_argument('--dataset', type=str, default='nasbench101',
|
||||
parser.add_argument('--dataset', type=str, default='natsbench_imagenet16-120',
|
||||
help='dataset on which experiments have been run')
|
||||
parser.add_argument('--conf_location', type=str, default='scripts/reports/proxynas_plots/cross_exp_conf.yaml',
|
||||
help='location of conf file')
|
||||
|
@ -39,10 +41,12 @@ def main():
|
|||
|
||||
exp_list = conf_data[args.dataset]['freezetrain']
|
||||
shortreg_exp_list = conf_data[args.dataset]['shortreg']
|
||||
zero_cost_exp_list = conf_data[args.dataset]['zero_cost']
|
||||
|
||||
# parse raw data from all processed experiments
|
||||
data = parse_raw_data(exp_folder, exp_list)
|
||||
shortreg_data = parse_raw_data(exp_folder, shortreg_exp_list)
|
||||
zero_cost_data = parse_raw_data(exp_folder, zero_cost_exp_list)
|
||||
|
||||
# collect linestyles and colors to create distinguishable plots
|
||||
cmap = plt.get_cmap('tab20')
|
||||
|
@ -126,6 +130,15 @@ def main():
|
|||
if j == 0:
|
||||
break
|
||||
|
||||
# get zero cost measures
|
||||
for j, key in enumerate(zero_cost_data.keys()):
|
||||
assert tp == zero_cost_data[key]['top_percents'][i]
|
||||
for measure in ZERO_COST_MEASURES:
|
||||
spe_name = measure + '_spe'
|
||||
cr_name = measure + '_ratio_common'
|
||||
this_tp_info['zero_cost_' + measure] = (0.0, zero_cost_data[key][spe_name][i], zero_cost_data[key][cr_name][i])
|
||||
|
||||
|
||||
# get shortreg
|
||||
for key in shortreg_data.keys():
|
||||
exp_name = key
|
||||
|
@ -145,6 +158,7 @@ def main():
|
|||
for ind, tp_key in enumerate(tp_info.keys()):
|
||||
counter = 0
|
||||
counter_reg = 0
|
||||
counter_zero = 0
|
||||
for exp in tp_info[tp_key].keys():
|
||||
duration = tp_info[tp_key][exp][0]
|
||||
spe = tp_info[tp_key][exp][1]
|
||||
|
@ -158,6 +172,10 @@ def main():
|
|||
marker = counter_reg
|
||||
marker_color = 'blue'
|
||||
counter_reg += 1
|
||||
elif 'zero_cost' in exp:
|
||||
marker = counter_zero
|
||||
marker_color = 'green'
|
||||
counter_zero += 1
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче