Zero-cost proxies now work for DARTS-space architectures but are broken for NATS-Bench architectures. We have to handle both output styles properly.

This commit is contained in:
Debadeepta Dey 2021-12-03 16:49:46 -08:00 committed by Gustavo Rosa
Parent 0bb248e20d
Commit 27e7d13fd1
9 changed files with 80 additions and 31 deletions

View file

@ -23,6 +23,8 @@ from typing import Tuple
from . import measure
from ..p_utils import get_layer_metric_array, reshape_elements
from archai.nas.model import Model
def fisher_forward_conv2d(self, x):
x = F.conv2d(x, self.weight, self.bias, self.stride,
@ -86,10 +88,13 @@ def compute_fisher_per_weight(net, inputs, targets, loss_fn, mode, split_data=1)
net.zero_grad()
outputs = net(inputs[st:en])
# natsbench sss produces (activation, logits) tuple
if isinstance(outputs, Tuple) and len(outputs) == 2:
outputs = outputs[1]
# TODO: We have to deal with different output styles of
# different APIs properly
# # natsbench sss produces (activation, logits) tuple
# if isinstance(outputs, Tuple) and len(outputs) == 2:
# outputs = outputs[1]
if isinstance(net, Model):
outputs, aux_logits = outputs[0], outputs[1]
loss = loss_fn(outputs, targets[st:en])
loss.backward()

View file

@ -20,6 +20,8 @@ from typing import Tuple
import copy
from archai.nas.model import Model
from . import measure
from ..p_utils import get_layer_metric_array
@ -32,9 +34,13 @@ def get_grad_norm_arr(net, inputs, targets, loss_fn, split_data=1, skip_grad=Fal
en=(sp+1)*N//split_data
outputs = net.forward(inputs[st:en])
# natsbench sss produces (activation, logits) tuple
if isinstance(outputs, Tuple) and len(outputs) == 2:
outputs = outputs[1]
# TODO: We have to deal with different output styles of
# different APIs properly
# # natsbench sss produces (activation, logits) tuple
# if isinstance(outputs, Tuple) and len(outputs) == 2:
# outputs = outputs[1]
if isinstance(net, Model):
outputs, aux_logits = outputs[0], outputs[1]
loss = loss_fn(outputs, targets[st:en])
loss.backward()

View file

@ -23,6 +23,8 @@ from typing import Tuple
from . import measure
from ..p_utils import get_layer_metric_array
from archai.nas.model import Model
@measure('grasp', bn=True, mode='param')
def compute_grasp_per_weight(net, inputs, targets, mode, loss_fn, T=1, num_iters=1, split_data=1):
@ -47,9 +49,15 @@ def compute_grasp_per_weight(net, inputs, targets, mode, loss_fn, T=1, num_iters
for _ in range(num_iters):
#TODO get new data, otherwise num_iters is useless!
outputs = net.forward(inputs[st:en])
# natsbench sss produces (activation, logits) tuple
if isinstance(outputs, Tuple) and len(outputs) == 2:
outputs = outputs[1]
# TODO: We have to deal with different output styles of
# different APIs properly
# # natsbench sss produces (activation, logits) tuple
# if isinstance(outputs, Tuple) and len(outputs) == 2:
# outputs = outputs[1]
if isinstance(net, Model):
outputs, aux_logits = outputs[0], outputs[1]
outputs = outputs/T
loss = loss_fn(outputs, targets[st:en])
@ -67,9 +75,13 @@ def compute_grasp_per_weight(net, inputs, targets, mode, loss_fn, T=1, num_iters
# forward/grad pass #2
outputs = net.forward(inputs[st:en])
# natsbench sss produces (activation, logits) tuple
if isinstance(outputs, Tuple) and len(outputs) == 2:
outputs = outputs[1]
# TODO: We have to deal with different output styles of
# different APIs properly
# # natsbench sss produces (activation, logits) tuple
# if isinstance(outputs, Tuple) and len(outputs) == 2:
# outputs = outputs[1]
if isinstance(net, Model):
outputs, aux_logits = outputs[0], outputs[1]
outputs = outputs/T
loss = loss_fn(outputs, targets[st:en])

View file

@ -19,6 +19,7 @@ from typing import Tuple
from . import measure
from archai.nas.model import Model
def get_batch_jacobian(net, x, target, device, split_data):
x.requires_grad_(True)
@ -28,9 +29,15 @@ def get_batch_jacobian(net, x, target, device, split_data):
st=sp*N//split_data
en=(sp+1)*N//split_data
y = net(x[st:en])
# natsbench sss produces (activation, logits) tuple
if isinstance(y, Tuple):
y = y[1]
# TODO: We have to deal with different output styles of
# different APIs properly
# # natsbench sss produces (activation, logits) tuple
# if isinstance(outputs, Tuple) and len(outputs) == 2:
# outputs = outputs[1]
if isinstance(net, Model):
y, aux_logits = y[0], y[1]
y.backward(torch.ones_like(y))
jacob = x.grad.detach()

View file

@ -21,6 +21,8 @@ from typing import Tuple
from . import measure
from ..p_utils import get_layer_metric_array
from archai.nas.model import Model
@measure('plain', bn=True, mode='param')
def compute_plain_per_weight(net, inputs, targets, mode, loss_fn, split_data=1):
@ -32,10 +34,13 @@ def compute_plain_per_weight(net, inputs, targets, mode, loss_fn, split_data=1):
en=(sp+1)*N//split_data
outputs = net.forward(inputs[st:en])
# natsbench sss produces (activation, logits) tuple
if isinstance(outputs, Tuple) and len(outputs) == 2:
outputs = outputs[1]
# TODO: We have to deal with different output styles of
# different APIs properly
# # natsbench sss produces (activation, logits) tuple
# if isinstance(outputs, Tuple) and len(outputs) == 2:
# outputs = outputs[1]
if isinstance(net, Model):
outputs, aux_logits = outputs[0], outputs[1]
loss = loss_fn(outputs, targets[st:en])
loss.backward()

View file

@ -25,6 +25,8 @@ from typing import Tuple
from . import measure
from ..p_utils import get_layer_metric_array
from archai.nas.model import Model
def snip_forward_conv2d(self, x):
return F.conv2d(x, self.weight * self.weight_mask, self.bias,
@ -55,9 +57,13 @@ def compute_snip_per_weight(net, inputs, targets, mode, loss_fn, split_data=1):
en=(sp+1)*N//split_data
outputs = net.forward(inputs[st:en])
# natsbench sss produces (activation, logits) tuple
if isinstance(outputs, Tuple) and len(outputs) == 2:
outputs = outputs[1]
# TODO: We have to deal with different output styles of
# different APIs properly
# # natsbench sss produces (activation, logits) tuple
# if isinstance(outputs, Tuple) and len(outputs) == 2:
# outputs = outputs[1]
if isinstance(net, Model):
outputs, aux_logits = outputs[0], outputs[1]
loss = loss_fn(outputs, targets[st:en])
loss.backward()

View file

@ -20,6 +20,8 @@ from typing import Tuple
from . import measure
from ..p_utils import get_layer_metric_array
from archai.nas.model import Model
@measure('synflow', bn=False, mode='param')
@measure('synflow_bn', bn=True, mode='param')
@ -53,10 +55,14 @@ def compute_synflow_per_weight(net, inputs, targets, mode, split_data=1, loss_fn
inputs = torch.ones([1] + input_dim).double().to(device)
output = net.forward(inputs)
# natsbench sss produces (activation, logits) tuple
if isinstance(output, Tuple) and len(output) == 2:
output = output[1]
# TODO: We have to deal with different output styles of
# different APIs properly
# # natsbench sss produces (activation, logits) tuple
# if isinstance(outputs, Tuple) and len(outputs) == 2:
# outputs = outputs[1]
if isinstance(net, Model):
output, aux_logits = output[0], output[1]
torch.sum(output).backward()
# select the gradients that we want to use for search/prune

View file

@ -37,8 +37,8 @@ class ZeroCostDartsSpaceConstantRandomEvaluator(Evaluater):
def create_model(self, conf_eval:Config, model_desc_builder:RandomModelDescBuilder,
final_desc_filename=None, full_desc_filename=None)->nn.Module:
assert model_desc_builder is not None, 'DartsSpaceEvaluater requires model_desc_builder'
assert final_desc_filename is None, 'DartsSpaceEvaluater creates its own model desc based on arch index'
assert model_desc_builder is not None, 'ZeroCostDartsSpaceConstantRandomEvaluator requires model_desc_builder'
assert final_desc_filename is None, 'ZeroCostDartsSpaceConstantRandomEvaluator creates its own model desc based on arch index'
assert type(model_desc_builder) == RandomModelDescBuilder, 'DartsSpaceEvaluater requires RandomModelDescBuilder'
# region conf vars
@ -48,6 +48,7 @@ class ZeroCostDartsSpaceConstantRandomEvaluator(Evaluater):
full_desc_filename = conf_eval['full_desc_filename']
conf_model_desc = conf_eval['model_desc']
arch_index = conf_eval['dartsspace']['arch_index']
self.num_classes = conf_eval['loader']['dataset']['n_classes']
# endregion
assert arch_index >= 0

View file

@ -18,6 +18,7 @@ from archai.nas.evaluater import EvalResult
from archai.common.common import get_expdir, logger
from archai.algos.proxynas.freeze_manual_searcher import ManualFreezeSearcher
from archai.algos.zero_cost_measures.zero_cost_darts_space_constant_random_evaluator import ZeroCostDartsSpaceConstantRandomEvaluator
from archai.algos.random_sample_darts_space.random_model_desc_builder import RandomModelDescBuilder
from nats_bench import create
@ -26,8 +27,8 @@ class ZeroCostDartsSpaceConstantRandomExperimentRunner(ExperimentRunner):
which are randomly sampled in a reproducible way"""
@overrides
def model_desc_builder(self)->Optional[ModelDescBuilder]:
return None
def model_desc_builder(self)->RandomModelDescBuilder:
return RandomModelDescBuilder()
@overrides
def trainer_class(self)->TArchTrainer: