Zero-cost proxies now work for DARTS architectures but are broken for NATS-Bench architectures; both must be handled properly.

This commit is contained in:
Debadeepta Dey 2021-12-03 16:49:46 -08:00 коммит произвёл Gustavo Rosa
Родитель 0bb248e20d
Коммит 27e7d13fd1
9 изменённых файлов: 80 добавлений и 31 удаление

Просмотреть файл

@ -23,6 +23,8 @@ from typing import Tuple
from . import measure from . import measure
from ..p_utils import get_layer_metric_array, reshape_elements from ..p_utils import get_layer_metric_array, reshape_elements
from archai.nas.model import Model
def fisher_forward_conv2d(self, x): def fisher_forward_conv2d(self, x):
x = F.conv2d(x, self.weight, self.bias, self.stride, x = F.conv2d(x, self.weight, self.bias, self.stride,
@ -86,10 +88,13 @@ def compute_fisher_per_weight(net, inputs, targets, loss_fn, mode, split_data=1)
net.zero_grad() net.zero_grad()
outputs = net(inputs[st:en]) outputs = net(inputs[st:en])
# natsbench sss produces (activation, logits) tuple # TODO: We have to deal with different output styles of
if isinstance(outputs, Tuple) and len(outputs) == 2: # different APIs properly
outputs = outputs[1] # # natsbench sss produces (activation, logits) tuple
# if isinstance(outputs, Tuple) and len(outputs) == 2:
# outputs = outputs[1]
if isinstance(net, Model):
outputs, aux_logits = outputs[0], outputs[1]
loss = loss_fn(outputs, targets[st:en]) loss = loss_fn(outputs, targets[st:en])
loss.backward() loss.backward()

Просмотреть файл

@ -20,6 +20,8 @@ from typing import Tuple
import copy import copy
from archai.nas.model import Model
from . import measure from . import measure
from ..p_utils import get_layer_metric_array from ..p_utils import get_layer_metric_array
@ -32,9 +34,13 @@ def get_grad_norm_arr(net, inputs, targets, loss_fn, split_data=1, skip_grad=Fal
en=(sp+1)*N//split_data en=(sp+1)*N//split_data
outputs = net.forward(inputs[st:en]) outputs = net.forward(inputs[st:en])
# natsbench sss produces (activation, logits) tuple # TODO: We have to deal with different output styles of
if isinstance(outputs, Tuple) and len(outputs) == 2: # different APIs properly
outputs = outputs[1] # # natsbench sss produces (activation, logits) tuple
# if isinstance(outputs, Tuple) and len(outputs) == 2:
# outputs = outputs[1]
if isinstance(net, Model):
outputs, aux_logits = outputs[0], outputs[1]
loss = loss_fn(outputs, targets[st:en]) loss = loss_fn(outputs, targets[st:en])
loss.backward() loss.backward()

Просмотреть файл

@ -23,6 +23,8 @@ from typing import Tuple
from . import measure from . import measure
from ..p_utils import get_layer_metric_array from ..p_utils import get_layer_metric_array
from archai.nas.model import Model
@measure('grasp', bn=True, mode='param') @measure('grasp', bn=True, mode='param')
def compute_grasp_per_weight(net, inputs, targets, mode, loss_fn, T=1, num_iters=1, split_data=1): def compute_grasp_per_weight(net, inputs, targets, mode, loss_fn, T=1, num_iters=1, split_data=1):
@ -47,9 +49,15 @@ def compute_grasp_per_weight(net, inputs, targets, mode, loss_fn, T=1, num_iters
for _ in range(num_iters): for _ in range(num_iters):
#TODO get new data, otherwise num_iters is useless! #TODO get new data, otherwise num_iters is useless!
outputs = net.forward(inputs[st:en]) outputs = net.forward(inputs[st:en])
# natsbench sss produces (activation, logits) tuple
if isinstance(outputs, Tuple) and len(outputs) == 2: # TODO: We have to deal with different output styles of
outputs = outputs[1] # different APIs properly
# # natsbench sss produces (activation, logits) tuple
# if isinstance(outputs, Tuple) and len(outputs) == 2:
# outputs = outputs[1]
if isinstance(net, Model):
outputs, aux_logits = outputs[0], outputs[1]
outputs = outputs/T outputs = outputs/T
loss = loss_fn(outputs, targets[st:en]) loss = loss_fn(outputs, targets[st:en])
@ -67,9 +75,13 @@ def compute_grasp_per_weight(net, inputs, targets, mode, loss_fn, T=1, num_iters
# forward/grad pass #2 # forward/grad pass #2
outputs = net.forward(inputs[st:en]) outputs = net.forward(inputs[st:en])
# natsbench sss produces (activation, logits) tuple # TODO: We have to deal with different output styles of
if isinstance(outputs, Tuple) and len(outputs) == 2: # different APIs properly
outputs = outputs[1] # # natsbench sss produces (activation, logits) tuple
# if isinstance(outputs, Tuple) and len(outputs) == 2:
# outputs = outputs[1]
if isinstance(net, Model):
outputs, aux_logits = outputs[0], outputs[1]
outputs = outputs/T outputs = outputs/T
loss = loss_fn(outputs, targets[st:en]) loss = loss_fn(outputs, targets[st:en])

Просмотреть файл

@ -19,6 +19,7 @@ from typing import Tuple
from . import measure from . import measure
from archai.nas.model import Model
def get_batch_jacobian(net, x, target, device, split_data): def get_batch_jacobian(net, x, target, device, split_data):
x.requires_grad_(True) x.requires_grad_(True)
@ -28,9 +29,15 @@ def get_batch_jacobian(net, x, target, device, split_data):
st=sp*N//split_data st=sp*N//split_data
en=(sp+1)*N//split_data en=(sp+1)*N//split_data
y = net(x[st:en]) y = net(x[st:en])
# natsbench sss produces (activation, logits) tuple
if isinstance(y, Tuple): # TODO: We have to deal with different output styles of
y = y[1] # different APIs properly
# # natsbench sss produces (activation, logits) tuple
# if isinstance(outputs, Tuple) and len(outputs) == 2:
# outputs = outputs[1]
if isinstance(net, Model):
y, aux_logits = y[0], y[1]
y.backward(torch.ones_like(y)) y.backward(torch.ones_like(y))
jacob = x.grad.detach() jacob = x.grad.detach()

Просмотреть файл

@ -21,6 +21,8 @@ from typing import Tuple
from . import measure from . import measure
from ..p_utils import get_layer_metric_array from ..p_utils import get_layer_metric_array
from archai.nas.model import Model
@measure('plain', bn=True, mode='param') @measure('plain', bn=True, mode='param')
def compute_plain_per_weight(net, inputs, targets, mode, loss_fn, split_data=1): def compute_plain_per_weight(net, inputs, targets, mode, loss_fn, split_data=1):
@ -32,10 +34,13 @@ def compute_plain_per_weight(net, inputs, targets, mode, loss_fn, split_data=1):
en=(sp+1)*N//split_data en=(sp+1)*N//split_data
outputs = net.forward(inputs[st:en]) outputs = net.forward(inputs[st:en])
# natsbench sss produces (activation, logits) tuple # TODO: We have to deal with different output styles of
if isinstance(outputs, Tuple) and len(outputs) == 2: # different APIs properly
outputs = outputs[1] # # natsbench sss produces (activation, logits) tuple
# if isinstance(outputs, Tuple) and len(outputs) == 2:
# outputs = outputs[1]
if isinstance(net, Model):
outputs, aux_logits = outputs[0], outputs[1]
loss = loss_fn(outputs, targets[st:en]) loss = loss_fn(outputs, targets[st:en])
loss.backward() loss.backward()

Просмотреть файл

@ -25,6 +25,8 @@ from typing import Tuple
from . import measure from . import measure
from ..p_utils import get_layer_metric_array from ..p_utils import get_layer_metric_array
from archai.nas.model import Model
def snip_forward_conv2d(self, x): def snip_forward_conv2d(self, x):
return F.conv2d(x, self.weight * self.weight_mask, self.bias, return F.conv2d(x, self.weight * self.weight_mask, self.bias,
@ -55,9 +57,13 @@ def compute_snip_per_weight(net, inputs, targets, mode, loss_fn, split_data=1):
en=(sp+1)*N//split_data en=(sp+1)*N//split_data
outputs = net.forward(inputs[st:en]) outputs = net.forward(inputs[st:en])
# natsbench sss produces (activation, logits) tuple # TODO: We have to deal with different output styles of
if isinstance(outputs, Tuple) and len(outputs) == 2: # different APIs properly
outputs = outputs[1] # # natsbench sss produces (activation, logits) tuple
# if isinstance(outputs, Tuple) and len(outputs) == 2:
# outputs = outputs[1]
if isinstance(net, Model):
outputs, aux_logits = outputs[0], outputs[1]
loss = loss_fn(outputs, targets[st:en]) loss = loss_fn(outputs, targets[st:en])
loss.backward() loss.backward()

Просмотреть файл

@ -20,6 +20,8 @@ from typing import Tuple
from . import measure from . import measure
from ..p_utils import get_layer_metric_array from ..p_utils import get_layer_metric_array
from archai.nas.model import Model
@measure('synflow', bn=False, mode='param') @measure('synflow', bn=False, mode='param')
@measure('synflow_bn', bn=True, mode='param') @measure('synflow_bn', bn=True, mode='param')
@ -53,10 +55,14 @@ def compute_synflow_per_weight(net, inputs, targets, mode, split_data=1, loss_fn
inputs = torch.ones([1] + input_dim).double().to(device) inputs = torch.ones([1] + input_dim).double().to(device)
output = net.forward(inputs) output = net.forward(inputs)
# natsbench sss produces (activation, logits) tuple # TODO: We have to deal with different output styles of
if isinstance(output, Tuple) and len(output) == 2: # different APIs properly
output = output[1] # # natsbench sss produces (activation, logits) tuple
# if isinstance(outputs, Tuple) and len(outputs) == 2:
# outputs = outputs[1]
if isinstance(net, Model):
output, aux_logits = output[0], output[1]
torch.sum(output).backward() torch.sum(output).backward()
# select the gradients that we want to use for search/prune # select the gradients that we want to use for search/prune

Просмотреть файл

@ -37,8 +37,8 @@ class ZeroCostDartsSpaceConstantRandomEvaluator(Evaluater):
def create_model(self, conf_eval:Config, model_desc_builder:RandomModelDescBuilder, def create_model(self, conf_eval:Config, model_desc_builder:RandomModelDescBuilder,
final_desc_filename=None, full_desc_filename=None)->nn.Module: final_desc_filename=None, full_desc_filename=None)->nn.Module:
assert model_desc_builder is not None, 'DartsSpaceEvaluater requires model_desc_builder' assert model_desc_builder is not None, 'ZeroCostDartsSpaceConstantRandomEvaluator requires model_desc_builder'
assert final_desc_filename is None, 'DartsSpaceEvaluater creates its own model desc based on arch index' assert final_desc_filename is None, 'ZeroCostDartsSpaceConstantRandomEvaluator creates its own model desc based on arch index'
assert type(model_desc_builder) == RandomModelDescBuilder, 'DartsSpaceEvaluater requires RandomModelDescBuilder' assert type(model_desc_builder) == RandomModelDescBuilder, 'DartsSpaceEvaluater requires RandomModelDescBuilder'
# region conf vars # region conf vars
@ -48,6 +48,7 @@ class ZeroCostDartsSpaceConstantRandomEvaluator(Evaluater):
full_desc_filename = conf_eval['full_desc_filename'] full_desc_filename = conf_eval['full_desc_filename']
conf_model_desc = conf_eval['model_desc'] conf_model_desc = conf_eval['model_desc']
arch_index = conf_eval['dartsspace']['arch_index'] arch_index = conf_eval['dartsspace']['arch_index']
self.num_classes = conf_eval['loader']['dataset']['n_classes']
# endregion # endregion
assert arch_index >= 0 assert arch_index >= 0

Просмотреть файл

@ -18,6 +18,7 @@ from archai.nas.evaluater import EvalResult
from archai.common.common import get_expdir, logger from archai.common.common import get_expdir, logger
from archai.algos.proxynas.freeze_manual_searcher import ManualFreezeSearcher from archai.algos.proxynas.freeze_manual_searcher import ManualFreezeSearcher
from archai.algos.zero_cost_measures.zero_cost_darts_space_constant_random_evaluator import ZeroCostDartsSpaceConstantRandomEvaluator from archai.algos.zero_cost_measures.zero_cost_darts_space_constant_random_evaluator import ZeroCostDartsSpaceConstantRandomEvaluator
from archai.algos.random_sample_darts_space.random_model_desc_builder import RandomModelDescBuilder
from nats_bench import create from nats_bench import create
@ -26,8 +27,8 @@ class ZeroCostDartsSpaceConstantRandomExperimentRunner(ExperimentRunner):
which are randomly sampled in a reproducible way""" which are randomly sampled in a reproducible way"""
@overrides @overrides
def model_desc_builder(self)->Optional[ModelDescBuilder]: def model_desc_builder(self)->RandomModelDescBuilder:
return None return RandomModelDescBuilder()
@overrides @overrides
def trainer_class(self)->TArchTrainer: def trainer_class(self)->TArchTrainer: