Mirror of https://github.com/microsoft/archai.git

diversenad branch merge

Parent: 08c5fc78ed
Commit: fe8d3efc45
@@ -109,6 +109,31 @@
            "console": "integratedTerminal",
            "args": ["--algos", "xnas"]
        },
        {
            "name": "Divnas-Full",
            "type": "python",
            "request": "launch",
            "program": "${cwd}/scripts/main.py",
            "console": "integratedTerminal",
            "args": ["--full", "--algos", "divnas"]
        },
        {
            "name": "Divnas-Search-Toy",
            "type": "python",
            "request": "launch",
            "program": "${cwd}/scripts/main.py",
            "console": "integratedTerminal",
            "args": ["--no-eval", "--algos", "divnas"]
        },
        {
            "name": "Divnas-E2E-Toy",
            "type": "python",
            "request": "launch",
            "program": "${cwd}/scripts/main.py",
            "console": "integratedTerminal",
            "args": ["--algos", "divnas"]
        },

        {
            "name": "Gs-Full",
            "type": "python",
@@ -1,9 +1,13 @@
from typing import Iterable, Optional, Tuple
from typing import Iterable, Optional, Tuple, List
from collections import deque

import torch
from torch import nn
import torch.nn.functional as F

import numpy as np
import h5py

from overrides import overrides

from archai.nas.model_desc import OpDesc
@@ -43,10 +47,10 @@ class MixedOp(Op):
                OpDesc(primitive, op_desc.params, in_len=1, trainables=None),
                affine=affine, alphas=alphas)
            self._ops.append(op)

    @overrides
    def forward(self, x):
        asm = F.softmax(self._alphas[0], dim=0)
        return sum(w * op(x) for w, op in zip(asm, self._ops))

    @overrides
@@ -72,6 +76,12 @@ class MixedOp(Op):
    def can_drop_path(self) -> bool:
        return False

    def get_op_desc(self, index:int)->OpDesc:
        ''' index: index in the primitives list '''
        assert index < len(self.PRIMITIVES)
        desc, _ = self._ops[index].finalize()
        return desc

    def _set_alphas(self, alphas: Iterable[nn.Parameter]) -> None:
        # must call before adding other ops
        assert len(list(self.parameters())) == 0
@@ -84,4 +94,4 @@ class MixedOp(Op):
            # asks back for the parameters in the object from Pytorch
            # which automagically registers the just created parameters.
            self._reg_alphas = new_p
            self._alphas = [p for p in self.parameters()]
@@ -0,0 +1,452 @@
import numpy as np
import pdb
from collections import defaultdict
from itertools import combinations_with_replacement
import matplotlib.pyplot as plt
import seaborn as sns
import math as ma
import h5py
import os
from copy import deepcopy
from typing import List, Set, Dict, Tuple, Any, Callable
from tqdm import tqdm
from itertools import permutations, combinations

from archai.algos.divnas.seqopt import SeqOpt


def create_submod_f(covariance:np.array)->Callable:
    def compute_marginal_gain_func(item:int, sub_sel:List[int], S:Set[int]):
        assert covariance.shape[0] == covariance.shape[1]
        assert len(covariance.shape) == 2
        assert len(S) == covariance.shape[0]

        sel_set = set(sub_sel)
        marg_gain = compute_marginal_gain(item, sel_set, S, covariance)
        return marg_gain
    return compute_marginal_gain_func


def get_batch(feature_list, batch_size, i):
    start_row = batch_size * i
    end_row = start_row + batch_size
    feats = [feat[start_row:end_row, :] for feat in feature_list]
    return feats


def rbf(x:np.array, y:np.array, sigma=0.1)->np.array:
    """ Computes the rbf kernel between two input vectors """

    # make sure that inputs are vectors
    assert len(x.shape) == 1
    assert len(y.shape) == 1

    sq_euclidean = np.sum(np.square(x-y))
    k = np.exp(-sq_euclidean/(2*sigma*sigma))
    return k


def _compute_mi(cov_kernel:np.array, A:Set, V_minus_A:Set):
    sigma_A = cov_kernel[np.ix_(list(A), list(A))]
    sigma_V_minus_A = cov_kernel[np.ix_(list(V_minus_A), list(V_minus_A))]
    I = 0.5 * np.log(np.linalg.det(sigma_A) * np.linalg.det(sigma_V_minus_A) / np.linalg.det(cov_kernel))
    return I


def compute_brute_force_sol(cov_kernel:np.array, budget:int)->Tuple[Tuple[Any], float]:
    assert cov_kernel.shape[0] == cov_kernel.shape[1]
    assert len(cov_kernel.shape) == 2
    assert budget > 0 and budget <= cov_kernel.shape[0]

    V = set(range(cov_kernel.shape[0]))

    # for each combination of budgeted items compute its mutual
    # information with the complement set
    mis = []
    for subset in combinations(range(cov_kernel.shape[0]), budget):
        A = set(subset)
        V_minus_A = V - A
        I = _compute_mi(cov_kernel, A, V_minus_A)
        mis.append((subset, I))

    # find the maximum subset
    max_subset, mi = max(mis, key=lambda x: x[1])
    return max_subset, mi


def compute_correlation(covariance:np.array)->np.array:
    variance = np.diag(covariance).reshape(-1, 1)
    stds = np.sqrt(np.matmul(variance, variance.T))
    correlation = covariance / (stds + 1e-16)
    return correlation


def compute_covariance_offline(feature_list:List[np.array])->np.array:
    """Compute covariance matrix for high-dimensional features.
    feature_shape: (num_samples, feature_dim)
    """
    num_features = len(feature_list)
    num_samples = feature_list[0].shape[0]
    flatten_features = [
        feas.reshape(num_samples, -1) for feas in feature_list]
    unbiased_features = [
        feas - np.mean(feas, 0) for feas in flatten_features]
    # (num_samples, feature_dim, num_features)
    features = np.stack(unbiased_features, -1)
    covariance = np.zeros((num_features, num_features), np.float32)
    for i in range(num_samples):
        covariance += np.matmul(features[i].T, features[i])
    return covariance


def compute_rbf_kernel_covariance(feature_list:List[np.array], sigma=0.1)->np.array:
    """ Compute rbf kernel covariance for high dimensional features.
    feature_list: List of features each of shape: (num_samples, feature_dim)
    sigma: sigma of the rbf kernel """
    num_features = len(feature_list)
    covariance = np.zeros((num_features, num_features), np.float32)

    for i in range(num_features):
        for j in range(num_features):
            if i == j:
                covariance[i][j] = covariance[j][i] = 1.0
                continue

            # NOTE: one could try to take all pairs rbf responses
            # but that is too much computation and probably does
            # not add much information
            feats_i = feature_list[i]
            feats_j = feature_list[j]
            assert feats_i.shape == feats_j.shape

            rbfs = np.exp(-np.sum(np.square(feats_i - feats_j), axis=1) / (2*sigma*sigma))
            avg_cov = np.sum(rbfs)/feats_i.shape[0]
            covariance[i][j] = covariance[j][i] = avg_cov

    return covariance


def compute_euclidean_dist_quantiles(feature_list:List[np.array], subsamplefactor=1)->List[Tuple[float, float]]:
    """ Compute quantile distances between feature pairs
    feature_list: List of features each of shape: (num_samples, feature_dim)
    """
    num_features = len(feature_list)
    num_samples = feature_list[0].shape[0]
    # (num_samples, feature_dim, num_features)
    features = np.stack(feature_list, -1)

    # compute all pairwise feature distances
    # too slow need to vectorize asap
    distances = []
    for i in range(num_features):
        for j in range(num_features):
            if i == j:
                continue

            for k in range(0, num_samples, subsamplefactor):
                feat_i = features[k, :][:, i]
                feat_j = features[k, :][:, j]
                dist = np.sqrt(np.sum(np.square(feat_i-feat_j)))
                distances.append(dist)

    quantiles = [i*0.1 for i in range(1, 10)]
    quant_vals = np.quantile(distances, quantiles)
    quants = []
    for quant, val in zip(quantiles, quant_vals.tolist()):
        quants.append((quant, val))
    return quants


def greedy_op_selection(covariance:np.array, k:int)->List[int]:
    assert covariance.shape[0] == covariance.shape[1]
    assert len(covariance.shape) == 2
    assert k <= covariance.shape[0]

    A = set()
    # to keep order information
    A_list = []

    S = set()
    for i in range(covariance.shape[0]):
        S.add(i)

    for i in tqdm(range(k)):
        marginal_gains = []
        marginal_gain_ids = []
        for y in S - A:
            delta_y = compute_marginal_gain(y, A, S, covariance)
            marginal_gains.append(delta_y)
            marginal_gain_ids.append(y)

        val = -ma.inf
        argmax = -1
        for marg_gain, marg_gain_id in zip(marginal_gains, marginal_gain_ids):
            if marg_gain > val:
                val = marg_gain
                argmax = marg_gain_id

        A.add(argmax)
        A_list.append(argmax)

    return A_list


def compute_marginal_gain(y:int, A:Set[int], S:Set[int], covariance:np.array)->float:
    if A:
        A_copy = deepcopy(A)
        A_copy.add(y)
    else:
        A_copy = set()
        A_copy.add(y)

    A_bar = S - A_copy

    sigma_y_sqr = covariance[y, y]

    if A:
        sigma_AA = covariance[np.ix_(list(A), list(A))]
        sigma_yA = covariance[np.ix_([y], list(A))]
        numerator = sigma_y_sqr - np.matmul(sigma_yA, np.matmul(np.linalg.inv(sigma_AA), sigma_yA.T))
    else:
        numerator = sigma_y_sqr

    if A_bar:
        sigma_AA_bar = covariance[np.ix_(list(A_bar), list(A_bar))]
        sigma_yA_bar = covariance[np.ix_([y], list(A_bar))]
        denominator = sigma_y_sqr - np.matmul(sigma_yA_bar, np.matmul(np.linalg.inv(sigma_AA_bar), sigma_yA_bar.T))
    else:
        denominator = sigma_y_sqr

    gain = numerator/denominator
    return float(gain)


def collect_features(rootfolder:str, subsampling_factor:int = 1)->Dict[str, List[np.array]]:
    """ Walks the rootfolder for h5py files and loads them into the format
    required for analysis.

    Inputs:

    rootfolder: full path to folder containing h5 files which have activations
    subsampling_factor: every nth minibatch will be loaded to keep memory manageable

    Outputs:

    dictionary with edge name strings as keys and values are lists of np.array [num_samples, feature_dim]
    """

    assert subsampling_factor > 0

    # gather all h5 files
    h5files = [os.path.join(rootfolder, f) for f in os.listdir(rootfolder) if os.path.isfile(os.path.join(rootfolder, f)) and '.h5' in f]
    assert h5files

    # storage for holding activations for all edges
    all_edges_activs = defaultdict(list)

    for h5file in h5files:
        with h5py.File(h5file, 'r') as hf:
            edge_name = h5file.split('/')[-1].split('.')[-2]
            edge_activ_list = []

            # load all batches
            keys_list = list(hf.keys())
            print(f'processing {h5file}, num batches {len(keys_list)}')
            for i in range(0, len(keys_list), subsampling_factor):
                key = keys_list[i]
                payload = np.array(hf.get(key))
                edge_activ_list.append(payload)

            obsv_dict = defaultdict(list)
            # separate activations by ops
            for batch in edge_activ_list:
                # assumption (num_ops, batch_size, x, y, z)
                for op in range(batch.shape[0]):
                    for b in range(batch.shape[1]):
                        feat = batch[op][b]
                        feat = feat.flatten()
                        obsv_dict[op].append(feat)

            num_ops = edge_activ_list[0].shape[0]
            feature_list = [np.zeros(1) for _ in range(num_ops)]
            for key in obsv_dict.keys():
                feat = np.array(obsv_dict[key])
                feature_list[key] = feat

            # removing none and skip_connect
            del feature_list[-1]
            del feature_list[2]

            all_edges_activs[edge_name] = feature_list

    return all_edges_activs


def plot_all_covs(covs_kernel, corr, primitives, axs):
    assert axs.shape[0] * axs.shape[1] == len(covs_kernel) + 1
    flat_axs = axs.flatten()

    for i, quantile in enumerate(covs_kernel.keys()):
        cov = covs_kernel[quantile]
        sns.heatmap(cov, annot=True, fmt='.1g', cmap='coolwarm', xticklabels=primitives, yticklabels=primitives, ax=flat_axs[i])
        flat_axs[i].set_title(f'Kernel covariance sigma={quantile} quantile')

    sns.heatmap(corr, annot=True, fmt='.1g', cmap='coolwarm', xticklabels=primitives, yticklabels=primitives, ax=flat_axs[-1])
    flat_axs[-1].set_title(f'Correlation')


def main():
    rootfile = '/media/dedey/DATADRIVE1/activations'

    all_edges_activs = collect_features(rootfile, subsampling_factor=5)

    PRIMITIVES = [
        'max_pool_3x3',
        'avg_pool_3x3',
        'sep_conv_3x3',
        'sep_conv_5x5',
        'dil_conv_3x3',
        'dil_conv_5x5',
    ]

    # # Use all edges
    # all_edges_list = []
    # all_names_list = []
    # for i in all_edges_activs.keys():
    #     all_edges_list.extend(all_edges_activs[i])
    #     for prim in PRIMITIVES:
    #         all_names_list.append(i + '_' + prim)

    # Use specific edges
    all_edges_list = []
    all_names_list = []

    # edge_list = ['activations_node_0_edge_0']
    edge_list = ['activations_node_0_edge_0', 'activations_node_0_edge_1']
    # edge_list = ['activations_node_1_edge_0', 'activations_node_1_edge_1', 'activations_node_1_edge_2']
    # edge_list = ['activations_node_2_edge_0', 'activations_node_2_edge_1', 'activations_node_2_edge_2', 'activations_node_2_edge_3']
    # edge_list = ['activations_node_3_edge_0', 'activations_node_3_edge_1', 'activations_node_3_edge_2', 'activations_node_3_edge_3', 'activations_node_3_edge_4']
    for name in edge_list:
        all_edges_list.extend(all_edges_activs[name])
        for prim in PRIMITIVES:
            all_names_list.append(name + '_' + prim)

    # compute covariance like usual
    # cov = compute_covariance_offline(all_edges_list)
    # corr = compute_correlation(cov)
    # sns.heatmap(corr, annot=False, xticklabels=all_names_list, yticklabels=all_names_list, cmap='coolwarm')
    # plt.axis('equal')
    # plt.show()

    # compute kernel covariance
    # quants = compute_euclidean_dist_quantiles(all_edges_list, subsamplefactor=20)
    cov_kernel_orig = compute_rbf_kernel_covariance(all_edges_list, sigma=168)
    cov_kernel = cov_kernel_orig + 1.0*np.eye(cov_kernel_orig.shape[0])
    print(f'Det before diag addition {np.linalg.det(cov_kernel_orig)}')
    print(f'Det after diag addition {np.linalg.det(cov_kernel)}')
    print(f'Condition number is {np.linalg.cond(cov_kernel)}')
    sns.heatmap(cov_kernel, annot=False, xticklabels=all_names_list, yticklabels=all_names_list, cmap='coolwarm')
    plt.axis('equal')
    plt.show()

    # brute force solution
    budget = 4
    bf_sensors, bf_val = compute_brute_force_sol(cov_kernel_orig, budget)
    print(f'Brute force max subset {bf_sensors}, max mi {bf_val}')

    # greedy
    print('Greedy selection')
    greedy_ops = greedy_op_selection(cov_kernel, cov_kernel.shape[0])

    for i, op_index in enumerate(greedy_ops):
        print(f'Greedy op {i} is {all_names_list[op_index]}')

    greedy_budget = greedy_ops[:budget]
    # find MI of the greedy solution
    V = set(range(cov_kernel.shape[0]))
    A_greedy = set(greedy_budget)
    V_minus_A_greedy = V - A_greedy
    I_greedy = _compute_mi(cov_kernel_orig, A_greedy, V_minus_A_greedy)
    print(f'Greedy solution is {greedy_budget}, mi is {I_greedy}')

    # seqopt
    # simulated batch size
    batch_size = 64
    num_batches = int(all_edges_list[0].shape[0] / batch_size)

    # seqopt object that will get updated in an online manner
    num_items = cov_kernel.shape[0]
    eps = 0.1
    seqopt = SeqOpt(num_items, eps)

    for i in tqdm(range(num_batches)):
        # simulate getting a new batch of activations
        sample = get_batch(all_edges_list, batch_size, i)

        # sample a list of activations from seqopt
        sel_list = seqopt.sample_sequence(with_replacement=False)

        # Using 50th percentile distance
        sigma = 168.0
        cov = compute_rbf_kernel_covariance(sample, sigma=sigma)

        # update seqopt
        compute_marginal_gain_func = create_submod_f(cov)
        seqopt.update(sel_list, compute_marginal_gain_func)

    # now sample a list of ops and hope it is diverse
    sel_list = seqopt.sample_sequence(with_replacement=False)
    # sel_primitives = [all_names_list for i in sel_list]
    # print(f'SeqOpt selected primitives are {sel_primitives}')

    # check that it is close to greedy and or bruteforce
    budget = 4
    sel_list = sel_list[:budget]
    # find MI of the greedy solution
    V = set(range(num_items))
    A_seqopt = set(sel_list)
    V_minus_A_seqopt = V - A_seqopt
    I_seqopt = _compute_mi(cov_kernel_orig, A_seqopt, V_minus_A_seqopt)
    print(f'SeqOpt solution is {sel_list}, mi is {I_seqopt}')

    # # For enumerating through many choices of rbf sigmas
    # covs_kernel = {}
    # for quantile, val in quants:
    #     print(f'Computing kernel covariance for quantile {quantile}')
    #     cov_kernel = compute_rbf_kernel_covariance(all_edges_list, sigma=val)
    #     covs_kernel[quantile] = cov_kernel

    # # compute greedy sequence of ops on one of the kernels
    # print('Greedy selection')
    # greedy_ops = greedy_op_selection(covs_kernel[0.5], 3)

    # for i, op_index in enumerate(greedy_ops):
    #     print(f'Greedy op {i} is {all_names_list[op_index]}')

    # fig, axs = plt.subplots(5, 2)
    # plot_all_covs(covs_kernel, corr, all_names_list, axs)
    # plt.show()


if __name__ == '__main__':
    main()
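For orientation, the objective that compute_brute_force_sol, greedy_op_selection and compute_marginal_gain manipulate can be written compactly. This is a sketch in the standard Gaussian mutual-information (sensor-placement) notation, assuming that is the intended objective: for a kernel covariance $\Sigma$ over the ground set $V$ and a subset $A \subseteq V$,

    I(A; V \setminus A) = \tfrac{1}{2}\,\log\frac{\det(\Sigma_{AA})\,\det(\Sigma_{\bar{A}\bar{A}})}{\det(\Sigma_{VV})}, \qquad \bar{A} = V \setminus A,

which is what _compute_mi returns, while the greedy step adds the item $y$ maximizing the conditional-variance ratio

    \Delta_y(A) = \frac{\sigma_y^2 - \Sigma_{yA}\,\Sigma_{AA}^{-1}\,\Sigma_{Ay}}{\sigma_y^2 - \Sigma_{y\bar{A}}\,\Sigma_{\bar{A}\bar{A}}^{-1}\,\Sigma_{\bar{A}y}}, \qquad \bar{A} = V \setminus (A \cup \{y\}),

which is the quantity compute_marginal_gain computes.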
@@ -0,0 +1,107 @@
from collections import defaultdict
from typing import Callable, Iterable, List, Optional, Tuple, Dict
from abc import ABC, abstractmethod

from overrides import overrides, EnforceOverrides

import numpy as np

import torch
from torch import nn, tensor
from overrides import overrides, EnforceOverrides

import archai.algos.divnas.analyse_activations as aa

from archai.nas.cell import Cell


class Divnas_Cell():
    ''' Wrapper cell class for divnas specific modifications '''
    def __init__(self, cell:Cell):
        self._cell = cell

        self._collect_activations = False
        self._edgeoptype = None
        self._sigma = None
        self._counter = 0
        self.node_covs:Dict[int, np.array] = {}

    def collect_activations(self, edgeoptype, sigma:float)->None:
        self._collect_activations = True
        self._edgeoptype = edgeoptype
        self._sigma = sigma

        # go through all edges in the DAG and if they are of edgeoptype
        # type then set them to collect activations
        for i, node in enumerate(self._cell.dag):
            # initialize the covariance matrix for this node
            num_ops = 0
            for edge in node:
                if hasattr(edge._op, 'PRIMITIVES') and type(edge._op) == self._edgeoptype:
                    num_ops += edge._op.num_valid_div_ops
                    edge._op.collect_activations = True

            self.node_covs[id(node)] = np.zeros((num_ops, num_ops))

    def update_covs(self):
        assert self._collect_activations

        for _, node in enumerate(self._cell.dag):
            # TODO: convert to explicit ordering
            all_activs = []
            for j, edge in enumerate(node):
                if type(edge._op) == self._edgeoptype:
                    activs = edge._op.activations
                    all_activs.append(activs)
            # update covariance matrix
            activs_converted = self._convert_activations(all_activs)
            new_cov = aa.compute_rbf_kernel_covariance(activs_converted, sigma=self._sigma)
            updated_cov = (self._counter * self.node_covs[id(node)] + new_cov) / (self._counter + 1)
            self.node_covs[id(node)] = updated_cov

    def clear_collect_activations(self):
        for _, node in enumerate(self._cell.dag):
            for edge in node:
                if hasattr(edge._op, 'PRIMITIVES') and type(edge._op) == self._edgeoptype:
                    edge._op.collect_activations = False

        self._collect_activations = False
        self._edgeoptype = None
        self._sigma = None
        self._node_covs = {}

    def _convert_activations(self, all_activs:List[List[np.array]])->List[np.array]:
        ''' Converts to the format needed by covariance computing functions
        Input all_activs: List[List[np.array]]. Outer list len is num_edges.
        Inner list is of num_ops length. Each element in inner list is [batch_size, x, y, z] '''

        num_ops = len(all_activs[0])
        for activs in all_activs:
            assert num_ops == len(activs)

        all_edge_list = []

        for edge in all_activs:
            obsv_dict = defaultdict(list)
            # assumption edge_np will be (num_ops, batch_size, x, y, z)
            edge_np = np.array(edge)
            for op in range(edge_np.shape[0]):
                for b in range(edge_np.shape[1]):
                    feat = edge_np[op][b]
                    feat = feat.flatten()
                    obsv_dict[op].append(feat)

            feature_list = [*range(num_ops)]
            for key in obsv_dict.keys():
                feat = np.array(obsv_dict[key])
                feature_list[key] = feat

            all_edge_list.extend(feature_list)

        return all_edge_list
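update_covs keeps a per-node running mean of the batch kernel covariances; assuming _counter advances once per processed batch (the increment is not visible in this hunk), the update is

    \bar{\Sigma}_{t+1} = \frac{t\,\bar{\Sigma}_t + \Sigma_{\text{batch}}}{t + 1},

so after $T$ batches the stored matrix is the plain average of the $T$ per-batch RBF-kernel covariances.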
@@ -0,0 +1,39 @@
from overrides import overrides

from archai.nas.cell_builder import CellBuilder
from archai.nas.operations import Op
from archai.nas.model_desc import ModelDesc, CellDesc, CellType, OpDesc, EdgeDesc
from .divop import DivOp

class DivnasCellBuilder(CellBuilder):
    @overrides
    def register_ops(self) -> None:
        Op.register_op('div_op',
                       lambda op_desc, alphas, affine:
                           DivOp(op_desc, alphas, affine))

    @overrides
    def build(self, model_desc:ModelDesc, search_iter:int)->None:
        for cell_desc in model_desc.cell_descs():
            self._build_cell(cell_desc)

    def _build_cell(self, cell_desc:CellDesc)->None:
        reduction = (cell_desc.cell_type==CellType.Reduction)

        # add div op for each edge in each node
        # how does the stride work? For all ops connected to s0 and s1, we apply
        # reduction in WxH. All ops connected elsewhere automatically get
        # reduced WxH (because all subsequent states are derived from s0 and s1).
        # Note that channel is increased via conv_params for the cell
        for i, node in enumerate(cell_desc.nodes()):
            for j in range(i+2):
                op_desc = OpDesc('div_op',
                                 params={
                                     'conv': cell_desc.conv_params,
                                     'stride': 2 if reduction and j < 2 else 1
                                 }, in_len=1, trainables=None, children=None)
                edge = EdgeDesc(op_desc, input_ids=[j])
                node.edges.append(edge)
@@ -0,0 +1,44 @@
from typing import Type

from overrides import overrides

from archai.common.common import get_conf
from archai.nas.exp_runner import ExperimentRunner
from archai.nas.arch_trainer import ArchTrainer, TArchTrainer
from archai.algos.darts.bilevel_arch_trainer import BilevelArchTrainer
from archai.algos.gumbelsoftmax.gs_arch_trainer import GsArchTrainer
from .divnas_cell_builder import DivnasCellBuilder
from .divnas_finalizers import DivnasFinalizers
from archai.nas.finalizers import Finalizers

class DivnasExperimentRunner(ExperimentRunner):

    @overrides
    def cell_builder(self)->DivnasCellBuilder:
        return DivnasCellBuilder()

    @overrides
    def trainer_class(self)->TArchTrainer:
        conf = get_conf()
        trainer = conf['nas']['search']['divnas']['archtrainer']

        if trainer == 'bilevel':
            return BilevelArchTrainer
        elif trainer == 'noalpha':
            return ArchTrainer
        else:
            raise NotImplementedError

    @overrides
    def finalizers(self)->Finalizers:
        conf = get_conf()
        finalizer = conf['nas']['search']['finalizer']

        if finalizer == 'mi':
            return DivnasFinalizers()
        else:
            return super().finalizers()
@@ -0,0 +1,134 @@
from typing import List, Tuple, Optional, Iterator, Dict
from overrides import overrides

import torch
from torch import nn

import numpy as np

from archai.common.common import get_conf
from archai.common.common import logger
from archai.datasets.data import get_data
from archai.nas.model import Model
from archai.nas.cell import Cell
from archai.nas.model_desc import CellDesc, ModelDesc, NodeDesc, EdgeDesc
from archai.nas.finalizers import Finalizers
from archai.algos.divnas.analyse_activations import compute_brute_force_sol
from archai.algos.divnas.divop import DivOp
from .divnas_cell import Divnas_Cell

class DivnasFinalizers(Finalizers):

    @overrides
    def finalize_model(self, model: Model, to_cpu=True, restore_device=True) -> ModelDesc:

        logger.pushd('finalize')

        # get config and train data loader
        # TODO: confirm this is correct in case you get silent bugs
        conf = get_conf()
        conf_loader = conf['nas']['search']['loader']
        train_dl, val_dl, test_dl = get_data(conf_loader)

        # wrap all cells in the model
        self._divnas_cells:Dict[int, Divnas_Cell] = {}
        for _, cell in enumerate(model.cells):
            divnas_cell = Divnas_Cell(cell)
            self._divnas_cells[id(cell)] = divnas_cell

        # go through all edges in the DAG and if they are of divop
        # type then set them to collect activations
        sigma = conf['nas']['search']['divnas']['sigma']
        for _, dcell in enumerate(self._divnas_cells.values()):
            dcell.collect_activations(DivOp, sigma)

        # now we need to run one evaluation epoch to collect activations
        # we do it on cpu otherwise we might run into memory issues
        # later we can redo the whole logic in pytorch itself
        # at the end of this each node in a cell will have the covariance
        # matrix of all incoming edges' ops
        model = model.cpu()
        model.eval()
        with torch.no_grad():
            for _ in range(1):
                for _, (x, _) in enumerate(train_dl):
                    _, _ = model(x), None
                    # now you can go through and update the
                    # node covariances in every cell
                    for dcell in self._divnas_cells.values():
                        dcell.update_covs()

        logger.popd()

        return super().finalize_model(model, to_cpu, restore_device)

    @overrides
    def finalize_cell(self, cell:Cell, *args, **kwargs)->CellDesc:
        # first finalize each node, we will need to recreate node desc with final version
        node_descs:List[NodeDesc] = []
        dcell = self._divnas_cells[id(cell)]
        assert len(cell.dag) == len(list(dcell.node_covs.values()))
        for node in cell.dag:
            node_cov = dcell.node_covs[id(node)]
            node_desc = self.finalize_node(node, cell.desc.max_final_edges, node_cov)
            node_descs.append(node_desc)

        # (optional) clear out all activation collection information
        dcell.clear_collect_activations()

        finalized = CellDesc(
            cell_type=cell.desc.cell_type,
            id=cell.desc.id,
            nodes=node_descs,
            s0_op=cell.s0_op.finalize()[0],
            s1_op=cell.s1_op.finalize()[0],
            alphas_from=cell.desc.alphas_from,
            max_final_edges=cell.desc.max_final_edges,
            node_ch_out=cell.desc.node_ch_out,
            post_op=cell.post_op.finalize()[0]
        )
        return finalized

    @overrides
    def finalize_node(self, node:nn.ModuleList, max_final_edges:int, cov:np.array, *args, **kwargs)->NodeDesc:
        # node is a list of edges
        assert len(node) >= max_final_edges

        # covariance matrix shape must be square 2-D
        assert len(cov.shape) == 2
        assert cov.shape[0] == cov.shape[1]

        # the number of primitive operators has to be greater
        # than or equal to the maximum number of final edges
        # allowed
        assert cov.shape[0] >= max_final_edges

        # get total number of ops incoming to this node
        num_ops = sum([edge._op.num_valid_div_ops for edge in node])

        # and collect some bookkeeping indices
        edge_num_and_op_ind = []
        for j, edge in enumerate(node):
            if type(edge._op) == DivOp:
                for k in range(edge._op.num_valid_div_ops):
                    edge_num_and_op_ind.append((j, k))

        assert len(edge_num_and_op_ind) == num_ops

        # run brute force set selection algorithm
        max_subset, max_mi = compute_brute_force_sol(cov, max_final_edges)

        # convert the cov indices to edge descs
        selected_edges = []
        for ind in max_subset:
            edge_ind, op_ind = edge_num_and_op_ind[ind]
            op_desc = node[edge_ind]._op.get_valid_op_desc(op_ind)
            new_edge = EdgeDesc(op_desc, node[edge_ind].input_ids)
            selected_edges.append(new_edge)

        # for edge in selected_edges:
        #     self.finalize_edge(edge)

        return NodeDesc(selected_edges)
@@ -0,0 +1,171 @@
from typing import Iterable, Optional, Tuple, List
from collections import deque

import torch
from torch import nn
import torch.nn.functional as F

import numpy as np
import h5py

from overrides import overrides

from archai.nas.model_desc import OpDesc
from archai.nas.operations import Op
from archai.common.common import get_conf

# TODO: reduction cell might have output reduced by 2^1=2X due to
# stride 2 through input nodes however FactorizedReduce does only
# 4X reduction. Is this correct?


class DivOp(Op):
    """The output of DivOp is weighted output of all allowed primitives.
    """

    PRIMITIVES = [
        'max_pool_3x3',
        'avg_pool_3x3',
        'skip_connect',  # identity
        'sep_conv_3x3',
        'sep_conv_5x5',
        'dil_conv_3x3',
        'dil_conv_5x5',
        'none'  # this must be at the end so top1 doesn't choose it
    ]

    # list of primitive ops not allowed in the
    # diversity calculation
    # NOTALLOWED = ['skip_connect', 'none']
    NOTALLOWED = ['none']

    def _indices_of_notallowed(self):
        ''' computes indices of notallowed ops in PRIMITIVES '''
        self._not_allowed_indices = []
        for op_name in self.NOTALLOWED:
            self._not_allowed_indices.append(self.PRIMITIVES.index(op_name))
        self._not_allowed_indices = sorted(self._not_allowed_indices, reverse=True)

    def _create_mapping_valid_to_orig(self):
        ''' Creates a list with indices of the valid ops to the original list '''
        self._valid_to_orig = []
        for i, prim in enumerate(self.PRIMITIVES):
            if prim in self.NOTALLOWED:
                continue
            else:
                self._valid_to_orig.append(i)

    def __init__(self, op_desc:OpDesc, alphas: Iterable[nn.Parameter],
                 affine:bool):
        super().__init__()

        # assume last PRIMITIVE is 'none'
        assert DivOp.PRIMITIVES[-1] == 'none'

        conf = get_conf()
        trainer = conf['nas']['search']['divnas']['archtrainer']
        finalizer = conf['nas']['search']['finalizer']

        if trainer == 'noalpha' and finalizer == 'default':
            raise NotImplementedError

        if trainer != 'noalpha':
            self._set_alphas(alphas)
        else:
            self._alphas = None

        self._ops = nn.ModuleList()
        for primitive in DivOp.PRIMITIVES:
            op = Op.create(
                OpDesc(primitive, op_desc.params, in_len=1, trainables=None),
                affine=affine, alphas=alphas)
            self._ops.append(op)

        # various state variables for diversity
        self._collect_activations = False
        self._forward_counter = 0
        self._batch_activs = None
        self._indices_of_notallowed()
        self._create_mapping_valid_to_orig()

    @property
    def collect_activations(self)->bool:
        return self._collect_activations

    @collect_activations.setter
    def collect_activations(self, to_collect:bool)->None:
        self._collect_activations = to_collect

    @property
    def activations(self)->Optional[List[np.array]]:
        return self._batch_activs

    @property
    def num_valid_div_ops(self)->int:
        return len(self.PRIMITIVES) - len(self.NOTALLOWED)

    @overrides
    def forward(self, x):
        # save activations to object
        if self._collect_activations:
            self._forward_counter += 1
            activs = [op(x) for op in self._ops]
            self._batch_activs = [t.cpu().detach().numpy() for t in activs]
            # delete the activations that are not allowed
            for index in self._not_allowed_indices:
                del self._batch_activs[index]

        if self._alphas:
            asm = F.softmax(self._alphas[0], dim=0)
            result = sum(w * op(x) for w, op in zip(asm, self._ops))
        else:
            result = sum(op(x) for op in self._ops)

        return result

    @overrides
    def alphas(self) -> Iterable[nn.Parameter]:
        if self._alphas:
            for alpha in self._alphas:
                yield alpha

    @overrides
    def weights(self) -> Iterable[nn.Parameter]:
        for op in self._ops:
            for w in op.parameters():
                yield w

    def get_op_desc(self, index:int)->OpDesc:
        ''' index: index in the primitives list '''
        assert index < len(self.PRIMITIVES)
        desc, _ = self._ops[index].finalize()
        return desc

    def get_valid_op_desc(self, index:int)->OpDesc:
        ''' index: index in the valid index list '''
        assert index <= self.num_valid_div_ops
        orig_index = self._valid_to_orig[index]
        desc, _ = self._ops[orig_index].finalize()
        return desc

    @overrides
    def can_drop_path(self) -> bool:
        return False

    def _set_alphas(self, alphas: Iterable[nn.Parameter]) -> None:
        # must call before adding other ops
        assert len(list(self.parameters())) == 0
        self._alphas = list(alphas)
        if not len(self._alphas):
            new_p = nn.Parameter(  # TODO: use better init than uniform random?
                1.0e-3*torch.randn(len(DivOp.PRIMITIVES)), requires_grad=True)
            # NOTE: This is a way to register parameters with PyTorch.
            # One creates a dummy variable with the parameters and then
            # asks back for the parameters in the object from Pytorch
            # which automagically registers the just created parameters.
            self._reg_alphas = new_p
            self._alphas = [p for p in self.parameters()]
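Read this way, DivOp.forward behaves like the usual DARTS-style mixed op: with architecture parameters $\alpha$ present the edge outputs

    o(x) = \sum_{k} \mathrm{softmax}(\alpha)_k \, o_k(x)

over all primitives in PRIMITIVES, and without alphas (the 'noalpha' trainer) it degenerates to the unweighted sum $\sum_k o_k(x)$; the activation snapshots kept for the diversity computation come only from the valid primitives, i.e. those not listed in NOTALLOWED.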
@@ -0,0 +1,105 @@
import numpy as np
from typing import List, Optional, Set, Dict

from archai.algos.divnas.wmr import Wmr


class SeqOpt:
    """ Implements SeqOpt
    TODO: Later on we might want to refactor this class
    to be able to handle bandit feedback """

    def __init__(self, num_items:int, eps:float):
        self._num_items = num_items

        # initialize wmr copies
        self._expert_algos = [Wmr(self._num_items, eps) for i in range(self._num_items)]

    def sample_sequence(self, with_replacement=False)->List[int]:
        sel_set = set()
        # to keep order information
        sel_list = []

        counter = 0
        counter_limit = 10000

        for i in range(self._num_items):
            item_id = self._expert_algos[i].sample()
            if not with_replacement:
                # NOTE: this might be an infinite while loop
                while item_id in sel_set and counter < counter_limit:
                    item_id = self._expert_algos[i].sample()
                    counter += 1

                if counter >= counter_limit:
                    print('Got caught in infinite loop for a while')

            sel_set.add(item_id)
            sel_list.append(item_id)

        return sel_list

    def _check_marg_gains(self, reward_storage:List[List[float]])->bool:
        reward_array = np.array(reward_storage)

        is_descending = True
        for i in range(reward_array.shape[1]):
            marg_gains_this_item = reward_array[:,i]
            is_descending = np.all(np.diff(marg_gains_this_item)<=0)
            if not is_descending:
                return is_descending

        return is_descending

    def _scale_minus_one_to_one(self, rewards:np.array)->np.array:
        scaled = np.interp(rewards, (rewards.min(), rewards.max()), (-1, 1))
        return scaled

    def update(self, sel_list:List[int], compute_marginal_gain_func)->None:
        """ In the full information case we will update
        all expert copies according to the marginal benefits """

        # mother set
        S = set([i for i in range(self._num_items)])

        reward_storage = []

        # for each slot
        for slot_id in range(self._num_items):
            # for each action in the slot
            sub_sel = set(sel_list[:slot_id])
            reward_vector = []
            for item in range(self._num_items):
                # the function passed in
                # must already be bound to the
                # covariance function needed
                reward = compute_marginal_gain_func(item, sub_sel, S)
                reward_vector.append(reward)

            # update the expert algo copy for this slot
            scaled_rewards = self._scale_minus_one_to_one(np.array(reward_vector))
            self._expert_algos[slot_id].update(scaled_rewards)

            reward_storage.append(reward_vector)

        # # Uncomment to aid in debugging
        # np.set_printoptions(precision=3, suppress=True)
        # print('Marginal gain array (item_id X slots)')
        # print(np.array(reward_storage).T)

        # is_descending = self._check_marg_gains(reward_storage)
        # if not is_descending:
        #     print('WARNING marginal gains are not diminishing')
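A minimal sketch of how SeqOpt is meant to be driven, mirroring the loop in analyse_activations.main() but with a small synthetic positive-definite matrix standing in for the real activation kernel covariance (the matrix, item count and round count below are illustrative only):

    import numpy as np
    from archai.algos.divnas.seqopt import SeqOpt
    from archai.algos.divnas.analyse_activations import create_submod_f

    rng = np.random.RandomState(0)
    M = rng.randn(4, 4)
    cov = M @ M.T + 4.0 * np.eye(4)      # toy positive-definite "covariance" over 4 items

    seqopt = SeqOpt(num_items=4, eps=0.1)
    gain_f = create_submod_f(cov)        # marginal-gain oracle bound to this covariance

    for _ in range(50):                  # online rounds (one per minibatch in the real code)
        sel = seqopt.sample_sequence(with_replacement=False)
        seqopt.update(sel, gain_f)       # full-information update of every slot's expert

    print(seqopt.sample_sequence(with_replacement=False))   # ordering should trend toward a diverse set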
@@ -0,0 +1,44 @@
import numpy as np


class Wmr:
    """ Implements the Randomized Weighted Majority algorithm by Littlestone and Warmuth.
    We use the gains version as in Fig. 1 of the Multiplicative Weights Update survey. """
    def __init__(self, num_items:int, eta:float):
        assert num_items > 0
        assert eta >= 0.0 and eta <= 0.5
        self._num_items = num_items
        self._eta = eta
        self._weights = self._normalize(np.ones(self._num_items))
        self._round_counter = 0

    @property
    def weights(self):
        return self._weights

    def _normalize(self, weights:np.array)->None:
        return weights / np.sum(weights)

    def update(self, rewards:np.array)->None:
        assert len(rewards.shape) == 1
        assert rewards.shape[0] == self._num_items
        assert np.all(rewards) >= -1 and np.all(rewards) <= 1.0

        # # annealed learning rate
        # self._round_counter += 1
        # eta = self._eta / np.sqrt(self._round_counter)
        eta = self._eta

        self._weights = self._weights * (1.0 + eta * rewards)
        self._weights = self._normalize(self._weights)

    def sample(self)->int:
        return np.random.choice(self._num_items, p=self._normalize(self._weights))
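The update rule in Wmr.update is the standard multiplicative-weights step in its gains form: for a reward vector $r \in [-1, 1]^n$,

    w_i \leftarrow w_i\,(1 + \eta\, r_i), \qquad w \leftarrow \frac{w}{\sum_j w_j},

and sample() draws an item with probability proportional to its current weight.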
@@ -1,6 +1,6 @@
from random import sample
from archai.common.utils import AverageMeter
from collections import defaultdict
from collections import defaultdict, deque
from typing import Iterable, Optional, Tuple, List

import torch
@@ -49,7 +49,6 @@ class GsOp(Op):
                affine=affine, alphas=alphas)
            self._ops.append(op)

    @overrides
    def forward(self, x):
        # soft sample from the categorical distribution
@@ -57,12 +56,11 @@ class GsOp(Op):
        # TODO: should we be normalizing the ensemble?
        #sampled = torch.zeros(self._alphas[0].size(), requires_grad=True)
        sample_storage = []
        for i in range(self._gs_num_sample):
        for _ in range(self._gs_num_sample):
            sampled = F.gumbel_softmax(self._alphas[0], tau=1, hard=False, eps=1e-10, dim=-1)
            sample_storage.append(sampled)

        samples_summed = torch.sum(torch.stack(sample_storage, dim=0), dim=0)

        return sum(w * op(x) for w, op in zip(samples_summed, self._ops))
@@ -120,6 +118,13 @@ class GsOp(Op):
    def can_drop_path(self) -> bool:
        return False

    def get_op_desc(self, index:int)->OpDesc:
        ''' index: index in the primitives list '''
        assert index < len(self.PRIMITIVES)
        desc, _ = self._ops[index].finalize()
        return desc

    def _set_alphas(self, alphas: Iterable[nn.Parameter]) -> None:
        # must call before adding other ops
@@ -180,6 +180,12 @@ class PetridishOp(Op):
        # rank=None to indicate no further selection needed as in darts
        return final_op_desc, None

    def get_op_desc(self, index:int)->OpDesc:
        ''' index: index in the primitives list '''
        assert index < len(self.PRIMITIVES)
        desc, _ = self._ops[index].finalize()
        return desc

    def _set_alphas(self, alphas: Iterable[nn.Parameter], in_len:int) -> None:
        assert len(list(self.parameters()))==0  # must call before adding other ops
@@ -13,6 +13,8 @@ from archai.common.config import Config
from archai.nas import evaluate
from archai.nas.search import Search
from archai.nas.finalizers import Finalizers
from archai.common.common import get_conf
from archai.nas.random_finalizers import RandomFinalizers


class ExperimentRunner(ABC, EnforceOverrides):
@@ -93,4 +95,13 @@ class ExperimentRunner(ABC, EnforceOverrides):
        pass

    def finalizers(self)->Finalizers:
        return Finalizers()
        conf = get_conf()
        finalizer = conf['nas']['search']['finalizer']

        if finalizer == 'default':
            return Finalizers()
        elif finalizer == 'random':
            return RandomFinalizers()
        else:
            raise NotImplementedError
@@ -43,7 +43,7 @@ class Finalizers(EnforceOverrides):
    def finalize_cells(self, model:Model)->List[CellDesc]:
        return [self.finalize_cell(cell) for cell in model.cells]

    def finalize_cell(self, cell:Cell)->CellDesc:
    def finalize_cell(self, cell:Cell, *args, **kwargs)->CellDesc:
        # first finalize each node, we will need to recreate node desc with final version
        node_descs:List[NodeDesc] = []
        for node in cell.dag:
@@ -63,7 +63,7 @@ class Finalizers(EnforceOverrides):
        )
        return finalized

    def finalize_node(self, node:nn.ModuleList, max_final_edges:int)->NodeDesc:
    def finalize_node(self, node:nn.ModuleList, max_final_edges:int, *args, **kwargs)->NodeDesc:
        # get edge ranks, if rank is None it is deemed as required
        pre_selected, edge_desc_ranks = self.get_edge_ranks(node)
        ranked_selected = self.select_edges(edge_desc_ranks, max_final_edges)
@@ -32,7 +32,7 @@ class MacroBuilder(EnforceOverrides):
        self.model_desc_params = conf_model_desc['params']
        # endregion

        # this satiesfies N R N R N pattern
        # this satisfies N R N R N pattern
        assert self.n_cells >= self.n_reductions * 2 + 1

        # for each reduction, we create one indice
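As a quick check of that assertion: with the default n_reductions = 2, the N R N R N layout needs at least 2*2 + 1 = 5 cells, so any n_cells >= 5 satisfies it.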
@@ -82,9 +82,7 @@ class EdgeDesc:
        ->'EdgeDesc':
        # edge cloning is same as deep copy except that we do it through
        # constructor for future proofing any additional future rules and
        # that we allow oveeriding conv_params and clearning weights. This later
        # bit used in model builder to create edge from template

        # that we allow overriding conv_params and clearing weights
        e = EdgeDesc(self.op_desc.clone(), self.input_ids)
        # op_desc should have params set from cloning. If no override supplied
        # then don't change it
@@ -8,6 +8,7 @@ from overrides import overrides, EnforceOverrides
import torch
from torch import affine_grid_generator, nn, Tensor, strided


from ..common import utils, ml_utils
from .model_desc import OpDesc, ConvMacroParams
@@ -117,6 +118,7 @@ class Op(nn.Module, ABC, EnforceOverrides):
    def can_drop_path(self)->bool:
        return True


class PoolBN(Op):
    """AvgPool or MaxPool - BN """
@@ -288,7 +290,7 @@ class Identity(Op):


class Zero(Op):
    """Represents no connection. Zero op can be thought of 1x1 kernal with fixed zero weight.
    """Represents no connection. Zero op can be thought of 1x1 kernel with fixed zero weight.
    For stride=1, it will produce output of same dimension as input but with all 0s. Now with stride of 2, it will zero out every other pixel in output.
    """
@@ -569,6 +571,7 @@ class MultiOp(Op):
                          in_len=1, trainables=None, children=None)
        self._ch_adj = Op.create(ch_adj_desc, affine=affine)


    @overrides
    def forward(self, x:Union[Tensor, List[Tensor]])->Tensor:
        # we may receive k=1..N tensors as inputs. Currently DagEdge will pass
@@ -0,0 +1,68 @@
from typing import List, Tuple, Optional, Iterator, Dict, Set
from overrides import overrides
import random

import torch
from torch import nn

import numpy as np

from archai.common.common import get_conf
from archai.common.common import logger
from archai.datasets.data import get_data
from archai.nas.model import Model
from archai.nas.cell import Cell
from archai.nas.model_desc import CellDesc, ModelDesc, NodeDesc, EdgeDesc
from archai.nas.finalizers import Finalizers
from archai.algos.divnas.analyse_activations import compute_brute_force_sol
from archai.algos.divnas.divop import DivOp


class RandomFinalizers(Finalizers):

    @overrides
    def finalize_node(self, node:nn.ModuleList, max_final_edges:int)->NodeDesc:
        # node is a list of edges
        assert len(node) >= max_final_edges

        # get total number of ops incoming to this node
        num_ops = 0
        for edge in node:
            if hasattr(edge._op, 'PRIMITIVES'):
                num_ops += len(edge._op.PRIMITIVES)

        # and collect some bookkeeping indices
        edge_num_and_op_ind = []
        for j, edge in enumerate(node):
            if hasattr(edge._op, 'PRIMITIVES'):
                for k in range(len(edge._op.PRIMITIVES)):
                    edge_num_and_op_ind.append((j, k))

        assert len(edge_num_and_op_ind) == num_ops

        # run random subset selection
        rand_subset = self._random_subset(num_ops, max_final_edges)

        # convert the cov indices to edge descs
        selected_edges = []
        for ind in rand_subset:
            edge_ind, op_ind = edge_num_and_op_ind[ind]
            op_desc = node[edge_ind]._op.get_op_desc(op_ind)
            new_edge = EdgeDesc(op_desc, node[edge_ind].input_ids)
            selected_edges.append(new_edge)

        return NodeDesc(selected_edges)

    def _random_subset(self, num_ops:int, max_final_edges:int)->Set[int]:
        assert num_ops > 0
        assert max_final_edges > 0
        assert max_final_edges <= num_ops

        S = set()
        while len(S) < max_final_edges:
            # randint is inclusive on both ends, so draw from [0, num_ops - 1]
            # to keep indices within range of edge_num_and_op_ind
            sample = random.randint(0, num_ops - 1)
            S.add(sample)

        return S
@@ -125,6 +125,7 @@ nas:
      type: 'CrossEntropyLoss'

  search:
    finalizer: 'default' # options are 'random' or 'default'
    data_parallel: False
    checkpoint:
      _copy: '/common/checkpoint'
@ -0,0 +1,281 @@
|
|||
__include__: "../datasets/cifar10.yaml" # default dataset settings are for cifar
|
||||
|
||||
common:
|
||||
experiment_name: 'throwaway' # you should supply from command line
|
||||
experiment_desc: 'throwaway'
|
||||
logdir: '~/logdir'
|
||||
seed: 2.0
|
||||
tb_enable: False # if True then TensorBoard logging is enabled (may impact perf)
|
||||
tb_dir: '$expdir/tb' # path where tensorboard logs would be stored
|
||||
checkpoint:
|
||||
filename: '$expdir/checkpoint.pth'
|
||||
freq: 10
|
||||
toy_mode: # this section will be used by toy.yaml to setup the toy mode
|
||||
max_batches: 4
|
||||
train_batch: 32
|
||||
test_batch: 64
|
||||
# TODO: workers setting
|
||||
|
||||
# reddis address of Ray cluster. Use None for single node run
|
||||
# otherwise it should something like host:6379. Make sure to run on head node:
|
||||
# "ray start --head --redis-port=6379"
|
||||
redis: null
|
||||
apex: # this is overriden in search and eval individually
|
||||
enabled: False # global switch to disable everything apex
|
||||
distributed_enabled: True # enable/disable distributed mode
|
||||
mixed_prec_enabled: True # switch to disable amp mixed precision
|
||||
gpus: '' # use GPU IDs specified here (comma separated), if '' then use all GPUs
|
||||
opt_level: 'O2' # optimization level for mixed precision
|
||||
bn_fp32: True # keep BN in fp32
|
||||
loss_scale: "dynamic" # loss scaling mode for mixed prec, must be string reprenting floar ot "dynamic"
|
||||
sync_bn: False # should be replace BNs with sync BNs for distributed model
|
||||
scale_lr: True # enable/disable distributed mode
|
||||
min_world_size: 0 # allows to confirm we are indeed in distributed setting
|
||||
detect_anomaly: False # if True, PyTorch code will run 6X slower
|
||||
seed: '_copy: /common/seed'
|
||||
|
||||
smoke_test: False
|
||||
only_eval: False
|
||||
resume: True
|
||||
|
||||
dataset: {} # default dataset settings comes from __include__ on the top
|
||||
|
||||
nas:
|
||||
eval:
|
||||
full_desc_filename: '$expdir/full_model_desc.yaml' # model desc used for building model for evaluation
|
||||
final_desc_filename: '$expdir/final_model_desc.yaml' # model desc used as template to construct cells
|
||||
|
||||
# If below is specified then final_desc_filename is ignored and model is created through factory function instead.
|
||||
# This is useful for running eval for manually designed models such as resnet-50.
|
||||
# The value is string of form 'some.namespace.module.function'. The function returns nn.Module and no required args.
|
||||
final_model_factory: ''
|
||||
|
||||
metric_filename: '$expdir/eval_train_metrics.yaml'
|
||||
model_filename: '$expdir/model.pt' # file to which trained model will be saved
|
||||
data_parallel: False
|
||||
checkpoint:
|
||||
_copy: '/common/checkpoint'
|
||||
resume: '_copy: /common/resume'
|
||||
model_desc:
|
||||
dataset:
|
||||
_copy: '/dataset'
|
||||
max_final_edges: 2 # max edge that can be in final arch per node
|
||||
cell_post_op: 'concate_channels'
|
||||
model_stem0_op: 'stem_conv3x3'
|
||||
model_stem1_op: 'stem_conv3x3'
|
||||
model_post_op: 'pool_adaptive_avg2d'
|
||||
aux_tower_stride: 3 # stride that aux tower should use, 3 is good for 32x32 images, 2 for imagenet
|
||||
stem_multiplier: 3 # output channels multiplier for the stem
|
||||
params: {}
|
||||
|
||||
n_nodes: 4 # number of nodes in a cell
|
||||
n_reductions: 2 # number of reductions to be applied
|
||||
|
||||
init_node_ch: 36 # num of input/output channels for nodes in 1st cell
|
||||
n_cells: 20 # number of cells
|
||||
aux_weight: 0.4 # weight for loss from auxiliary towers in test time arch
|
||||
loader:
|
||||
apex:
|
||||
_copy: '../../trainer/apex'
|
||||
aug: '' # additional augmentations to use
|
||||
cutout: 16 # cutout length, use cutout augmentation when > 0
|
||||
load_train: True # load train split of dataset
|
||||
train_batch: 96
|
||||
train_workers: 4
|
||||
test_workers: '_copy: ../train_workers' # if null then 4
|
||||
load_test: True # load test split of dataset
|
||||
test_batch: 1024
|
||||
val_ratio: 0.0 #split portion for test set, 0 to 1
|
||||
val_fold: 0 #Fold number to use (0 to 4)
|
||||
cv_num: 5 # total number of folds available
|
||||
dataset:
|
||||
_copy: '/dataset'
|
||||
trainer:
|
||||
apex:
|
||||
_copy: '/common/apex'
|
||||
aux_weight: '_copy: /nas/eval/model_desc/aux_weight'
|
||||
drop_path_prob: 0.2 # probability that given edge will be dropped
|
||||
grad_clip: 5.0 # grads above this value is clipped
|
||||
l1_alphas: 0.0 # weight to be applied to sum(abs(alphas)) to loss term
|
||||
logger_freq: 1000 # after every N updates dump loss and other metrics in logger
|
||||
title: 'eval_train'
|
||||
epochs: 600
|
||||
batch_chunks: 1 # split batch into these many chunks and accumulate gradients so we can support GPUs with lower RAM
|
||||
lossfn:
|
||||
type: 'CrossEntropyLoss'
|
||||
optimizer:
|
||||
type: 'sgd'
|
||||
lr: 0.025 # init learning rate
|
||||
decay: 3.0e-4 # pytorch default is 0.0
|
||||
momentum: 0.9 # pytorch default is 0.0
|
||||
nesterov: False # pytorch default is False
|
||||
decay_bn: .NaN # if NaN then same as decay otherwise apply different decay to BN layers
|
||||
lr_schedule:
|
||||
type: 'cosine'
|
||||
min_lr: 0.001 # min learning rate to se bet in eta_min param of scheduler
|
||||
warmup: # increases LR for 0 to current in specified epochs and then hands over to main scheduler
|
||||
multiplier: 1
|
||||
epochs: 0 # 0 disables warmup
|
||||
validation:
|
||||
title: 'eval_test'
|
||||
batch_chunks: '_copy: ../../batch_chunks' # split batch into these many chunks and accumulate gradients so we can support GPUs with lower RAM
|
||||
logger_freq: 0
|
||||
freq: 1 # perform validation only every N epochs
|
||||
lossfn:
|
||||
type: 'CrossEntropyLoss'
|
||||
|
||||
search:
finalizer: 'random' # options are mutual-information based 'mi', 'random' or 'default'. NOTE: 'default' is not compatible with the 'noalpha' trainer as 'default' uses the darts finalizer and needs alphas
divnas:
sigma: 168
archtrainer: 'bilevel' # options are 'bilevel', 'noalpha'
data_parallel: False
checkpoint:
_copy: '/common/checkpoint'
resume: '_copy: /common/resume'
search_iters: 1
collect_activations: True
full_desc_filename: '$expdir/full_model_desc.yaml' # arch before it was finalized
final_desc_filename: '$expdir/final_model_desc.yaml' # final arch is saved in this file
metrics_dir: '$expdir/models/{reductions}/{cells}/{nodes}/{search_iter}' # where metrics and model stats would be saved from each pareto iteration
seed_train:
trainer:
_copy: '/nas/eval/trainer'
title: 'seed_train'
epochs: 0 # number of epochs model will be trained before search
aux_weight: 0.0
drop_path_prob: 0.0
loader:
_copy: '/nas/eval/loader'
train_batch: 128
val_ratio: 0.1 # split portion for validation set, 0 to 1
post_train:
trainer:
_copy: '/nas/eval/trainer'
title: 'post_train'
epochs: 0 # number of epochs model will be trained after search
aux_weight: 0.0
drop_path_prob: 0.0
loader:
_copy: '/nas/eval/loader'
train_batch: 128
val_ratio: 0.1 # split portion for validation set, 0 to 1
pareto:
# default parameters are set so there is exactly one search iteration
max_cells: 8
max_reductions: 2
max_nodes: 4
enabled: False
summary_filename: '$expdir/perito.tsv' # for each iteration of macro, we save model and perf summary
model_desc:
# we avoid copying from eval node because dataset settings
# may override eval.model_desc with different stems, pool etc
dataset:
_copy: '/dataset'
max_final_edges: 2 # max edges that can be in final arch per node
cell_post_op: 'concate_channels'
model_stem0_op: 'stem_conv3x3'
model_stem1_op: 'stem_conv3x3'
model_post_op: 'pool_adaptive_avg2d'
aux_tower_stride: 3 # stride that aux tower should use, 3 is good for 32x32 images, 2 for imagenet
stem_multiplier: 3 # output channels multiplier for the stem
params: {}

n_nodes: 4 # number of nodes in a cell
n_reductions: 2 # number of reductions to be applied

init_node_ch: 16 # num of input/output channels for nodes in 1st cell
n_cells: 8 # number of cells
aux_weight: 0.0 # weight for loss from auxiliary towers in test time arch
loader:
apex:
_copy: '../../trainer/apex'
aug: '' # additional augmentations to use
cutout: 0 # cutout length, use cutout augmentation when > 0
load_train: True # load train split of dataset
train_batch: 64
train_workers: 4 # if null then gpu_count*4
test_workers: '_copy: ../train_workers' # if null then 4
load_test: False # load test split of dataset
test_batch: 1024
val_ratio: 0.5 # split portion for validation set, 0 to 1
val_fold: 0 # fold number to use (0 to 4)
cv_num: 5 # total number of folds available
dataset:
_copy: '/dataset'
trainer:
apex:
_copy: '/common/apex'
aux_weight: '_copy: /nas/search/model_desc/aux_weight'
drop_path_prob: 0.0 # probability that given edge will be dropped
grad_clip: 5.0 # grads above this value are clipped
logger_freq: 1000 # after every N updates dump loss and other metrics in logger
title: 'arch_train'
epochs: 50
batch_chunks: 1 # split batch into this many chunks and accumulate gradients so we can support GPUs with lower RAM
# additional vals for the derived class
plotsdir: '' # empty string means no plots, otherwise plots are generated for each epoch in this dir
l1_alphas: 0.0 # weight applied to sum(abs(alphas)) in the loss term
lossfn:
type: 'CrossEntropyLoss'
optimizer:
type: 'sgd'
lr: 0.025 # init learning rate
decay: 3.0e-4
momentum: 0.9 # pytorch default is 0
nesterov: False
decay_bn: .NaN # if NaN then same as decay, otherwise apply different decay to BN layers
alpha_optimizer:
type: 'adam'
lr: 3.0e-4
decay: 1.0e-3
betas: [0.5, 0.999]
decay_bn: .NaN # if NaN then same as decay, otherwise apply different decay to BN layers
lr_schedule:
type: 'cosine'
min_lr: 0.001 # min learning rate, this will be used in eta_min param of scheduler
warmup: null
validation:
title: 'search_val'
logger_freq: 0
batch_chunks: '_copy: ../../batch_chunks' # split batch into this many chunks and accumulate gradients so we can support GPUs with lower RAM
freq: 1 # perform validation only every N epochs
lossfn:
type: 'CrossEntropyLoss'

autoaug:
num_op: 2
num_policy: 5
num_search: 200
num_result_per_cv: 10 # after conducting N trials, we will choose the top num_result_per_cv results
loader:
apex:
_copy: '/common/apex'
aug: '' # additional augmentations to use
cutout: 16 # cutout length, use cutout augmentation when > 0
epochs: 50
load_train: True # load train split of dataset
train_batch: 64
train_workers: 4 # if null then gpu_count*4
test_workers: '_copy: ../train_workers' # if null then 4
load_test: True # load test split of dataset
test_batch: 1024
val_ratio: 0.4 # split portion for validation set, 0 to 1
val_fold: 0 # fold number to use (0 to 4)
cv_num: 5 # total number of folds available
dataset:
_copy: '/dataset'
optimizer:
type: 'sgd'
lr: 0.025 # init learning rate
decay: 3.0e-4 # pytorch default is 0.0
momentum: 0.9 # pytorch default is 0.0
nesterov: False # pytorch default is False
clip: 5.0 # grads above this value are clipped # TODO: Why is this also in trainer?
decay_bn: .NaN # if NaN then same as decay, otherwise apply different decay to BN layers
#betas: [0.9, 0.999] # PyTorch default betas for Adam
lr_schedule:
type: 'cosine'
min_lr: 0.0 # min learning rate, this will be used in eta_min param of scheduler
warmup: null
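Throughout the config above, `_copy` entries reference other parts of the same YAML tree: a path starting with '/' is resolved from the root (e.g. '/dataset', '/common/apex'), while '../'-style paths walk up from the referencing key (e.g. test_workers copying its sibling train_workers). Below is a minimal sketch of how such references could be resolved over a plain dict, assuming those path semantics; the resolve_copies helper is only an illustration and is not archai's actual config implementation.

# Illustrative only: a tiny resolver for '_copy' references over a nested dict,
# assuming '/x/y' means "from the config root" and '../x' means "relative to
# the referencing key". Not archai's implementation.
from typing import Any, Dict, List, Optional

def _lookup(root: Dict[str, Any], ref_path: List[str], ref: str) -> Any:
    ref = ref.strip()
    cur = [] if ref.startswith('/') else list(ref_path)
    for part in ref.split('/'):
        if part == '..':
            cur.pop()
        elif part:
            cur.append(part)
    val: Any = root
    for key in cur:
        val = val[key]
    return val

def resolve_copies(root: Dict[str, Any], node: Optional[Dict[str, Any]] = None,
                   path: Optional[List[str]] = None) -> Dict[str, Any]:
    node = root if node is None else node
    path = [] if path is None else path
    for key, val in list(node.items()):
        if isinstance(val, dict):
            if set(val) == {'_copy'}:
                node[key] = _lookup(root, path + [key], val['_copy'])
            else:
                resolve_copies(root, val, path + [key])
        elif isinstance(val, str) and val.startswith('_copy:'):
            node[key] = _lookup(root, path + [key], val.split(':', 1)[1])
    return root

# e.g. test_workers: '_copy: ../train_workers' resolves to the sibling value:
conf = {'loader': {'train_workers': 4, 'test_workers': '_copy: ../train_workers'}}
assert resolve_copies(conf)['loader']['test_workers'] == 4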
@@ -8,6 +8,7 @@ from archai.algos.random.random_exp_runner import RandomExperimentRunner
from archai.algos.manual.manual_exp_runner import ManualExperimentRunner
from archai.algos.xnas.xnas_exp_runner import XnasExperimentRunner
from archai.algos.gumbelsoftmax.gs_exp_runner import GsExperimentRunner
from archai.algos.divnas.divnas_exp_runner import DivnasExperimentRunner

def main():
    runner_types:Dict[str, Type[ExperimentRunner]] = {

@@ -16,14 +17,15 @@ def main():
        'xnas': XnasExperimentRunner,
        'random': RandomExperimentRunner,
        'manual': ManualExperimentRunner,
        'gs': GsExperimentRunner
        'gs': GsExperimentRunner,
        'divnas': DivnasExperimentRunner
    }

    parser = argparse.ArgumentParser(description='NAS E2E Runs')
    parser.add_argument('--algos', type=str, default='darts,petridish,xnas,random,gs,manual',
    parser.add_argument('--algos', type=str, default='darts,petridish,xnas,random,gs,divnas,manual',
        help='NAS algos to run, seperated by comma')
    parser.add_argument('--datasets', type=str, default='cifar10',
        help='datasets to use, seperated by comma')
        help='datasets to use, separated by comma')
    parser.add_argument('--full', type=lambda x:x.lower()=='true',
        nargs='?', const=True, default=False,
        help='Run in full or toy mode just to check for compile errors')
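With this change, 'divnas' becomes just another key in runner_types; main() splits the comma-separated --algos string and instantiates the matching runner for each requested algorithm. A rough sketch of that dispatch pattern with stand-in classes (the constructor and run() signatures below are simplified illustrations, not archai's exact API):

# Sketch of the registry/dispatch pattern used by scripts/main.py.
# ExperimentRunner and DivnasExperimentRunner here are simplified stand-ins.
from typing import Dict, Type

class ExperimentRunner:
    def __init__(self, config_filename: str, base_name: str) -> None:
        self.config_filename, self.base_name = config_filename, base_name
    def run(self) -> None:
        print(f'running {self.base_name} from {self.config_filename}')

class DivnasExperimentRunner(ExperimentRunner):
    pass

runner_types: Dict[str, Type[ExperimentRunner]] = {
    'divnas': DivnasExperimentRunner,
    # ...the other algorithms are registered the same way
}

algos = 'divnas'  # e.g. the parsed --algos value
for algo in (a.strip() for a in algos.split(',') if a.strip()):
    runner_cls = runner_types[algo]
    runner_cls(f'confs/{algo}_cifar.yaml', base_name=algo).run()  # config path is illustrative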
setup.py
@@ -11,7 +11,7 @@ install_requires=[
    'hyperopt', # @ git+https://github.com/hyperopt/hyperopt.git
    'tensorwatch>=0.9.1', 'tensorboard',
    'pretrainedmodels', 'tqdm', 'sklearn', 'matplotlib', 'psutil',
    'requests',
    'requests', 'seaborn',
    'gorilla', 'pyyaml', 'overrides', 'runstats', 'psutil', 'statopt'
]

@@ -0,0 +1,275 @@
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import combinations
from copy import deepcopy
import math as ma
import unittest
from tqdm import tqdm
from typing import Any, Callable, List, Tuple, Set

import archai.algos.divnas.analyse_activations as aa
from archai.algos.divnas.seqopt import SeqOpt
from archai.algos.divnas.analyse_activations import _compute_mi, compute_brute_force_sol
from archai.algos.divnas.online_analyse_activations import create_submod_f
from archai.algos.divnas.wmr import Wmr


def create_rbf_func(first:np.array, sigma:float)->Callable:
    assert len(first.shape) == 1
    assert sigma >= 0.0
    def rbf_bound(second:np.array):
        assert len(second.shape) == 1
        val = aa.rbf(first, second, sigma)
        return val
    return rbf_bound

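# Note: rbf_bound closes over a fixed center `first` and bandwidth `sigma`;
# assuming aa.rbf is the standard Gaussian kernel exp(-||x - y||^2 / (2*sigma^2))
# (the same form used explicitly in compute_synthetic_data_covariance below),
# e.g. create_rbf_func(np.array([3.0]), 0.1) returns a function that evaluates
# to 1.0 at x = 3.0 and falls off to ~0 one grid step away.
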
def synthetic_data2()->List[Tuple[np.array, np.array]]:
    # num grid locations
    num_loc = 10
    # plop some kernels on 0, 3, 6, 9
    k_0_func = create_rbf_func(np.array([0.0]), 3.0)
    k_3_func = create_rbf_func(np.array([3.0]), 0.1)
    k_6_func = create_rbf_func(np.array([6.0]), 0.1)
    k_9_func = create_rbf_func(np.array([9.0]), 0.5)

    Y = []
    for i in range(num_loc):
        i_arr = np.array([i])
        y = 20.0 * k_0_func(i_arr) - 25.0 * k_3_func(i_arr) + 100.0 * k_6_func(i_arr) - 100.0 * k_9_func(i_arr)
        y_arr = np.array([y])
        Y.append((y_arr, i_arr))

    return Y

def synthetic_data()->List[Tuple[np.array, np.array]]:
    # num grid locations
    num_loc = 10
    # plop some kernels on 0, 3, 6, 9
    k_0_func = create_rbf_func(np.array([0.0]), 3.0)
    k_3_func = create_rbf_func(np.array([3.0]), 0.1)
    k_6_func = create_rbf_func(np.array([6.0]), 0.1)
    k_9_func = create_rbf_func(np.array([9.0]), 0.5)

    Y = []
    for i in range(num_loc):
        i_arr = np.array([i])
        y = -10.0 * k_0_func(i_arr) + 25.0 * k_3_func(i_arr) - 100.0 * k_6_func(i_arr) - 100.0 * k_9_func(i_arr)
        y_arr = np.array([y])
        Y.append((y_arr, i_arr))

    return Y

def compute_synthetic_data_covariance(Y:List[Tuple[np.array, np.array]], sigma=0.8):
    num_obsvs = len(Y)
    covariance = np.zeros((num_obsvs, num_obsvs), np.float32)

    for i in range(num_obsvs):
        for j in range(num_obsvs):
            if i == j:
                covariance[i][j] = covariance[j][i] = 1.0
                continue

            obsv_i = Y[i][0]
            obsv_j = Y[j][0]
            assert obsv_i.shape == obsv_j.shape
            if len(obsv_i.shape) == 1:
                obsv_i = np.reshape(obsv_i, (obsv_i.shape[0], 1))
                obsv_j = np.reshape(obsv_j, (obsv_j.shape[0], 1))

            rbfs = np.exp(-np.sum(np.square(obsv_i - obsv_j), axis=1) / (2*sigma*sigma))
            avg_cov = np.sum(rbfs)/obsv_i.shape[0]
            covariance[i][j] = covariance[j][i] = avg_cov

    return covariance

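# A quick sanity check one might run on the synthetic covariance above (sketch;
# it relies only on the helpers defined in this file):
def _sanity_check_covariance() -> None:
    Y = synthetic_data2()
    cov = compute_synthetic_data_covariance(Y, sigma=0.8)
    assert cov.shape == (len(Y), len(Y))
    assert np.allclose(cov, cov.T)              # symmetric by construction
    assert np.allclose(np.diag(cov), 1.0)       # unit diagonal
    assert np.all((cov >= 0.0) & (cov <= 1.0))  # RBF similarities lie in [0, 1]
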
class SeqOptSyntheticDataTestCase(unittest.TestCase):

    def setUp(self):
        self.Y = synthetic_data2()
        self.vals = [item[0] for item in self.Y]
        self.cov_kernel = compute_synthetic_data_covariance(self.Y)

    def test_marginal_gain_calculation(self):
        """ Tests that marginal gain calculation is correct """
        V = set(range(self.cov_kernel.shape[0]))
        A_random = set([1])
        V_minus_A_random = V - A_random
        y = 2
        I_A_random = _compute_mi(self.cov_kernel, A_random, V_minus_A_random)
        A_aug = deepcopy(A_random)
        A_aug.add(y)
        V_minus_A_aug = V - A_aug
        I_A_aug = _compute_mi(self.cov_kernel, A_aug, V_minus_A_aug)
        diff_via_direct = abs(I_A_aug - I_A_random)
        print(f'MI(A) {I_A_random}, MI(A U y) {I_A_aug}, diff {diff_via_direct}')

        diff = aa.compute_marginal_gain(y, A_random, V, self.cov_kernel)
        # the marginal_gain leaves out 0.5 * log term as it does not
        # matter for ranking elements
        half_log_diff = 0.5 * np.log(diff)
        print(f'Diff via aa.compute {half_log_diff}')
        self.assertAlmostEqual(diff_via_direct, half_log_diff, delta=0.01)

    def test_greedy(self):
        # budgeted number of sensors
        budget = 4

        # brute force solution
        bf_sensors, bf_val = compute_brute_force_sol(self.cov_kernel, budget)
        print(f'Brute force max subset {bf_sensors}, max mi {bf_val}')

        # greedy
        greedy_sensors = aa.greedy_op_selection(self.cov_kernel, budget)
        # find MI of the greedy solution
        V = set(range(self.cov_kernel.shape[0]))
        A_greedy = set(greedy_sensors)
        V_minus_A_greedy = V - A_greedy
        I_greedy = _compute_mi(self.cov_kernel, A_greedy, V_minus_A_greedy)
        print(f'Greedy solution is {greedy_sensors}, mi is {I_greedy}')

        self.assertAlmostEqual(bf_val, I_greedy, delta=0.1)

    def test_wmr(self):
        eta = 0.01
        num_rounds = 10000
        gt_distrib = [0.15, 0.5, 0.3, 0.05]
        num_items = len(gt_distrib)
        wmr = Wmr(num_items, eta)

        for _ in range(num_rounds):
            sampled_index = np.random.choice(num_items, p=gt_distrib)
            rewards = np.zeros((num_items))
            rewards[sampled_index] = 1.0
            wmr.update(rewards)

        print(wmr.weights)
        self.assertTrue(wmr.weights[1] > 0.4)

    def test_seqopt(self):

        # budgeted number of sensors
        budget = 4

        # brute force solution
        bf_sensors, bf_val = compute_brute_force_sol(self.cov_kernel, budget)
        print(f'Brute force max subset {bf_sensors}, max mi {bf_val}')

        # greedy
        greedy_sensors = aa.greedy_op_selection(self.cov_kernel, budget)
        # find MI of the greedy solution
        V = set(range(self.cov_kernel.shape[0]))
        A_greedy = set(greedy_sensors)
        V_minus_A_greedy = V - A_greedy
        I_greedy = _compute_mi(self.cov_kernel, A_greedy, V_minus_A_greedy)
        print(f'Greedy solution is {greedy_sensors}, mi is {I_greedy}')

        # online greedy
        eps = 0.1
        num_items = self.cov_kernel.shape[0]
        seqopt = SeqOpt(num_items, eps)
        num_rounds = 100

        for i in tqdm(range(num_rounds)):

            # sample a list of activations from seqopt
            sel_list = seqopt.sample_sequence(with_replacement=False)

            # NOTE: we are going to use the batch covariance
            # every round as this is a toy setting and we want to
            # verify that seqopt is converging to good solutions

            # update seqopt
            compute_marginal_gain_func = create_submod_f(self.cov_kernel)
            seqopt.update(sel_list, compute_marginal_gain_func)

        # now sample a list of ops and hope it is diverse
        seqopt_sensors = seqopt.sample_sequence(with_replacement=False)
        seqopt_sensors = seqopt_sensors[:budget]

        V = set(range(self.cov_kernel.shape[0]))
        A_seqopt = set(seqopt_sensors)
        V_minus_A_seqopt = V - A_seqopt
        I_seqopt = _compute_mi(self.cov_kernel, A_seqopt, V_minus_A_seqopt)
        print(f'SeqOpt solution is {seqopt_sensors}, mi is {I_seqopt}')

        self.assertAlmostEqual(I_seqopt, I_greedy, delta=0.1)
        self.assertAlmostEqual(I_greedy, bf_val, delta=0.1)

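# The 0.5 * log relationship checked in test_marginal_gain_calculation above is
# the standard Gaussian identity
#   I(A + {y}; rest) - I(A; rest) = 0.5 * log(Var(y|A) / Var(y|A_bar)),
# where A_bar is everything outside A + {y}. Below is a self-contained numpy
# check of that identity (a sketch independent of archai's helpers; it assumes
# jointly Gaussian variables with the given covariance):
def _check_gaussian_mi_identity(seed: int = 0) -> None:
    rng = np.random.default_rng(seed)
    X = rng.normal(size=(6, 6))
    cov = X @ X.T + 6.0 * np.eye(6)  # a well-conditioned covariance

    def h(S: Set[int]) -> float:
        # Gaussian entropy of the variables in S, up to additive constants
        idx = sorted(S)
        return 0.5 * np.log(np.linalg.det(cov[np.ix_(idx, idx)]))

    def mi(A: Set[int], V: Set[int]) -> float:
        # I(A; V \ A); the (2*pi*e) constants cancel in this combination
        return h(A) + h(V - A) - h(V)

    def cond_var(y: int, S: Set[int]) -> float:
        # Var(y | S) via the Schur complement
        idx = sorted(S)
        b = cov[y, idx]
        return cov[y, y] - b @ np.linalg.solve(cov[np.ix_(idx, idx)], b)

    V = set(range(cov.shape[0]))
    A, y = {1, 3}, 4
    direct_gain = mi(A | {y}, V) - mi(A, V)
    ratio = cond_var(y, A) / cond_var(y, V - (A | {y}))
    assert np.isclose(direct_gain, 0.5 * np.log(ratio))
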
def main():
    unittest.main()


# # generate some synthetic 1d data
# Y = synthetic_data2()
# vals = [item[0] for item in Y]
# print(f'{np.unique(vals).shape[0]} unique observations' )
# plt.figure()
# plt.plot(vals)
# # plt.show()

# # budget on sensor
# budget = 4

# # compute kernel covariance of observations
# cov_kernel = compute_synthetic_data_covariance(Y)
# print(f'Det of cov_kernel is {np.linalg.det(cov_kernel)}')

# plt.figure()
# sns.heatmap(cov_kernel, annot=False, cmap='coolwarm')
# # plt.show()

# # brute force solution
# bf_sensors, bf_val = compute_brute_force_sol(cov_kernel, budget)
# print(f'Brute force max subset {bf_sensors}, max mi {bf_val}')

# # greedy
# greedy_sensors = aa.greedy_op_selection(cov_kernel, budget)
# # find MI of the greedy solution
# V = set(range(cov_kernel.shape[0]))
# A_greedy = set(greedy_sensors)
# V_minus_A_greedy = V - A_greedy
# I_greedy = _compute_mi(cov_kernel, A_greedy, V_minus_A_greedy)
# print(f'Greedy solution is {greedy_sensors}, mi is {I_greedy}')

# # online greedy
# eps = 0.1
# num_items = cov_kernel.shape[0]
# seqopt = SeqOpt(num_items, eps)
# num_rounds = 100
# for i in range(num_rounds):
# print(f'Round {i}/{num_rounds}')

# # sample a list of activations from seqopt
# sel_list = seqopt.sample_sequence(with_replacement=False)

# # NOTE: we are going to use the batch covariance
# # every round as this is a toy setting and we want to
# # verify that seqopt is converging to good solutions

# # update seqopt
# compute_marginal_gain_func = create_submod_f(cov_kernel)
# seqopt.update(sel_list, compute_marginal_gain_func)

# # now sample a list of ops and hope it is diverse
# seqopt_sensors = seqopt.sample_sequence(with_replacement=False)
# V = set(range(cov_kernel.shape[0]))
# A_seqopt = set(seqopt_sensors)
# V_minus_A_seqopt = V - A_seqopt
# I_seqopt = _compute_mi(cov_kernel, A_seqopt, V_minus_A_seqopt)
# print(f'SeqOpt solution is {seqopt_sensors}, mi is {I_seqopt}')


if __name__ == '__main__':
    main()