From 27e7d13fd13864842e0148a6bdd8d3f35b0e7324 Mon Sep 17 00:00:00 2001
From: Debadeepta Dey
Date: Fri, 3 Dec 2021 16:49:46 -0800
Subject: [PATCH] Zero-cost proxies now work for darts architectures but are
 broken for natsbench architectures. Have to handle both properly.

---
 .../pruners/measures/fisher.py                | 13 ++++++----
 .../pruners/measures/grad_norm.py             | 12 +++++++---
 .../pruners/measures/grasp.py                 | 24 ++++++++++++++-----
 .../pruners/measures/jacob_cov.py             | 13 +++++++---
 .../pruners/measures/plain.py                 | 13 ++++++----
 .../pruners/measures/snip.py                  | 12 +++++++---
 .../pruners/measures/synflow.py               | 14 +++++++----
 ...t_darts_space_constant_random_evaluator.py |  5 ++--
 ...space_constant_random_experiment_runner.py |  5 ++--
 9 files changed, 80 insertions(+), 31 deletions(-)

diff --git a/archai/algos/zero_cost_measures/pruners/measures/fisher.py b/archai/algos/zero_cost_measures/pruners/measures/fisher.py
index 0e94e5eb..dd699c05 100644
--- a/archai/algos/zero_cost_measures/pruners/measures/fisher.py
+++ b/archai/algos/zero_cost_measures/pruners/measures/fisher.py
@@ -23,6 +23,8 @@ from typing import Tuple
 
 from . import measure
 from ..p_utils import get_layer_metric_array, reshape_elements
 
+from archai.nas.model import Model
+
 def fisher_forward_conv2d(self, x):
     x = F.conv2d(x, self.weight, self.bias, self.stride,
@@ -86,10 +88,13 @@ def compute_fisher_per_weight(net, inputs, targets, loss_fn, mode, split_data=1)
 
             net.zero_grad()
             outputs = net(inputs[st:en])
-            # natsbench sss produces (activation, logits) tuple
-            if isinstance(outputs, Tuple) and len(outputs) == 2:
-                outputs = outputs[1]
-
+            # TODO: We have to deal with different output styles of
+            # different APIs properly
+            # # natsbench sss produces (activation, logits) tuple
+            # if isinstance(outputs, Tuple) and len(outputs) == 2:
+            #     outputs = outputs[1]
+            if isinstance(net, Model):
+                outputs, aux_logits = outputs[0], outputs[1]
             loss = loss_fn(outputs, targets[st:en])
             loss.backward()
 
diff --git a/archai/algos/zero_cost_measures/pruners/measures/grad_norm.py b/archai/algos/zero_cost_measures/pruners/measures/grad_norm.py
index 5014694b..03f48631 100644
--- a/archai/algos/zero_cost_measures/pruners/measures/grad_norm.py
+++ b/archai/algos/zero_cost_measures/pruners/measures/grad_norm.py
@@ -20,6 +20,8 @@ from typing import Tuple
 
 import copy
 
+from archai.nas.model import Model
+
 from . import measure
 from ..p_utils import get_layer_metric_array
 
@@ -32,9 +34,13 @@ def get_grad_norm_arr(net, inputs, targets, loss_fn, split_data=1, skip_grad=Fal
         en=(sp+1)*N//split_data
 
         outputs = net.forward(inputs[st:en])
-        # natsbench sss produces (activation, logits) tuple
-        if isinstance(outputs, Tuple) and len(outputs) == 2:
-            outputs = outputs[1]
+        # TODO: We have to deal with different output styles of
+        # different APIs properly
+        # # natsbench sss produces (activation, logits) tuple
+        # if isinstance(outputs, Tuple) and len(outputs) == 2:
+        #     outputs = outputs[1]
+        if isinstance(net, Model):
+            outputs, aux_logits = outputs[0], outputs[1]
         loss = loss_fn(outputs, targets[st:en])
         loss.backward()
 
diff --git a/archai/algos/zero_cost_measures/pruners/measures/grasp.py b/archai/algos/zero_cost_measures/pruners/measures/grasp.py
index 1d34b689..2db4e1ac 100644
--- a/archai/algos/zero_cost_measures/pruners/measures/grasp.py
+++ b/archai/algos/zero_cost_measures/pruners/measures/grasp.py
@@ -23,6 +23,8 @@ from typing import Tuple
 
 from . import measure
 from ..p_utils import get_layer_metric_array
 
+from archai.nas.model import Model
+
 @measure('grasp', bn=True, mode='param')
 def compute_grasp_per_weight(net, inputs, targets, mode, loss_fn, T=1, num_iters=1, split_data=1):
@@ -47,9 +49,15 @@ def compute_grasp_per_weight(net, inputs, targets, mode, loss_fn, T=1, num_iters
         for _ in range(num_iters):
             #TODO get new data, otherwise num_iters is useless!
             outputs = net.forward(inputs[st:en])
-            # natsbench sss produces (activation, logits) tuple
-            if isinstance(outputs, Tuple) and len(outputs) == 2:
-                outputs = outputs[1]
+
+            # TODO: We have to deal with different output styles of
+            # different APIs properly
+            # # natsbench sss produces (activation, logits) tuple
+            # if isinstance(outputs, Tuple) and len(outputs) == 2:
+            #     outputs = outputs[1]
+            if isinstance(net, Model):
+                outputs, aux_logits = outputs[0], outputs[1]
+
             outputs = outputs/T
             loss = loss_fn(outputs, targets[st:en])
 
@@ -67,9 +75,13 @@ def compute_grasp_per_weight(net, inputs, targets, mode, loss_fn, T=1, num_iters
 
         # forward/grad pass #2
         outputs = net.forward(inputs[st:en])
-        # natsbench sss produces (activation, logits) tuple
-        if isinstance(outputs, Tuple) and len(outputs) == 2:
-            outputs = outputs[1]
+        # TODO: We have to deal with different output styles of
+        # different APIs properly
+        # # natsbench sss produces (activation, logits) tuple
+        # if isinstance(outputs, Tuple) and len(outputs) == 2:
+        #     outputs = outputs[1]
+        if isinstance(net, Model):
+            outputs, aux_logits = outputs[0], outputs[1]
         outputs = outputs/T
         loss = loss_fn(outputs, targets[st:en])
 
diff --git a/archai/algos/zero_cost_measures/pruners/measures/jacob_cov.py b/archai/algos/zero_cost_measures/pruners/measures/jacob_cov.py
index aa5fae90..3852a47b 100644
--- a/archai/algos/zero_cost_measures/pruners/measures/jacob_cov.py
+++ b/archai/algos/zero_cost_measures/pruners/measures/jacob_cov.py
@@ -19,6 +19,7 @@ from typing import Tuple
 
 from . import measure
 
+from archai.nas.model import Model
 
 def get_batch_jacobian(net, x, target, device, split_data):
     x.requires_grad_(True)
@@ -28,9 +29,15 @@ def get_batch_jacobian(net, x, target, device, split_data):
         st=sp*N//split_data
         en=(sp+1)*N//split_data
         y = net(x[st:en])
-        # natsbench sss produces (activation, logits) tuple
-        if isinstance(y, Tuple):
-            y = y[1]
+
+        # TODO: We have to deal with different output styles of
+        # different APIs properly
+        # # natsbench sss produces (activation, logits) tuple
+        # if isinstance(y, Tuple) and len(y) == 2:
+        #     y = y[1]
+        if isinstance(net, Model):
+            y, aux_logits = y[0], y[1]
+
         y.backward(torch.ones_like(y))
         jacob = x.grad.detach()
 
diff --git a/archai/algos/zero_cost_measures/pruners/measures/plain.py b/archai/algos/zero_cost_measures/pruners/measures/plain.py
index 5a853511..90cb8620 100644
--- a/archai/algos/zero_cost_measures/pruners/measures/plain.py
+++ b/archai/algos/zero_cost_measures/pruners/measures/plain.py
@@ -21,6 +21,8 @@ from typing import Tuple
 
 from . import measure
 from ..p_utils import get_layer_metric_array
 
+from archai.nas.model import Model
+
 @measure('plain', bn=True, mode='param')
 def compute_plain_per_weight(net, inputs, targets, mode, loss_fn, split_data=1):
@@ -32,10 +34,13 @@ def compute_plain_per_weight(net, inputs, targets, mode, loss_fn, split_data=1):
         en=(sp+1)*N//split_data
 
         outputs = net.forward(inputs[st:en])
-        # natsbench sss produces (activation, logits) tuple
-        if isinstance(outputs, Tuple) and len(outputs) == 2:
-            outputs = outputs[1]
-
+        # TODO: We have to deal with different output styles of
+        # different APIs properly
+        # # natsbench sss produces (activation, logits) tuple
+        # if isinstance(outputs, Tuple) and len(outputs) == 2:
+        #     outputs = outputs[1]
+        if isinstance(net, Model):
+            outputs, aux_logits = outputs[0], outputs[1]
         loss = loss_fn(outputs, targets[st:en])
         loss.backward()
 
diff --git a/archai/algos/zero_cost_measures/pruners/measures/snip.py b/archai/algos/zero_cost_measures/pruners/measures/snip.py
index c395c2aa..33ff981d 100644
--- a/archai/algos/zero_cost_measures/pruners/measures/snip.py
+++ b/archai/algos/zero_cost_measures/pruners/measures/snip.py
@@ -25,6 +25,8 @@ from typing import Tuple
 
 from . import measure
 from ..p_utils import get_layer_metric_array
 
+from archai.nas.model import Model
+
 def snip_forward_conv2d(self, x):
     return F.conv2d(x, self.weight * self.weight_mask, self.bias,
@@ -55,9 +57,13 @@ def compute_snip_per_weight(net, inputs, targets, mode, loss_fn, split_data=1):
         en=(sp+1)*N//split_data
 
         outputs = net.forward(inputs[st:en])
-        # natsbench sss produces (activation, logits) tuple
-        if isinstance(outputs, Tuple) and len(outputs) == 2:
-            outputs = outputs[1]
+        # TODO: We have to deal with different output styles of
+        # different APIs properly
+        # # natsbench sss produces (activation, logits) tuple
+        # if isinstance(outputs, Tuple) and len(outputs) == 2:
+        #     outputs = outputs[1]
+        if isinstance(net, Model):
+            outputs, aux_logits = outputs[0], outputs[1]
         loss = loss_fn(outputs, targets[st:en])
         loss.backward()
 
diff --git a/archai/algos/zero_cost_measures/pruners/measures/synflow.py b/archai/algos/zero_cost_measures/pruners/measures/synflow.py
index 5e15d508..6abbc87f 100644
--- a/archai/algos/zero_cost_measures/pruners/measures/synflow.py
+++ b/archai/algos/zero_cost_measures/pruners/measures/synflow.py
@@ -20,6 +20,8 @@ from typing import Tuple
 
 from . import measure
 from ..p_utils import get_layer_metric_array
 
+from archai.nas.model import Model
+
 @measure('synflow', bn=False, mode='param')
 @measure('synflow_bn', bn=True, mode='param')
@@ -53,10 +55,14 @@ def compute_synflow_per_weight(net, inputs, targets, mode, split_data=1, loss_fn
 
     inputs = torch.ones([1] + input_dim).double().to(device)
     output = net.forward(inputs)
-    # natsbench sss produces (activation, logits) tuple
-    if isinstance(output, Tuple) and len(output) == 2:
-        output = output[1]
-
+    # TODO: We have to deal with different output styles of
+    # different APIs properly
+    # # natsbench sss produces (activation, logits) tuple
+    # if isinstance(output, Tuple) and len(output) == 2:
+    #     output = output[1]
+    if isinstance(net, Model):
+        output, aux_logits = output[0], output[1]
+
     torch.sum(output).backward()
 
     # select the gradients that we want to use for search/prune
diff --git a/archai/algos/zero_cost_measures/zero_cost_darts_space_constant_random_evaluator.py b/archai/algos/zero_cost_measures/zero_cost_darts_space_constant_random_evaluator.py
index 2494fff4..a1ca1f5f 100644
--- a/archai/algos/zero_cost_measures/zero_cost_darts_space_constant_random_evaluator.py
+++ b/archai/algos/zero_cost_measures/zero_cost_darts_space_constant_random_evaluator.py
@@ -37,8 +37,8 @@ class ZeroCostDartsSpaceConstantRandomEvaluator(Evaluater):
 
     def create_model(self, conf_eval:Config, model_desc_builder:RandomModelDescBuilder,
                      final_desc_filename=None, full_desc_filename=None)->nn.Module:
-        assert model_desc_builder is not None, 'DartsSpaceEvaluater requires model_desc_builder'
-        assert final_desc_filename is None, 'DartsSpaceEvaluater creates its own model desc based on arch index'
+        assert model_desc_builder is not None, 'ZeroCostDartsSpaceConstantRandomEvaluator requires model_desc_builder'
+        assert final_desc_filename is None, 'ZeroCostDartsSpaceConstantRandomEvaluator creates its own model desc based on arch index'
         assert type(model_desc_builder) == RandomModelDescBuilder, 'DartsSpaceEvaluater requires RandomModelDescBuilder'
 
         # region conf vars
@@ -48,6 +48,7 @@ class ZeroCostDartsSpaceConstantRandomEvaluator(Evaluater):
         full_desc_filename = conf_eval['full_desc_filename']
         conf_model_desc = conf_eval['model_desc']
         arch_index = conf_eval['dartsspace']['arch_index']
+        self.num_classes = conf_eval['loader']['dataset']['n_classes']
         # endregion
 
         assert arch_index >= 0
diff --git a/archai/algos/zero_cost_measures/zero_cost_darts_space_constant_random_experiment_runner.py b/archai/algos/zero_cost_measures/zero_cost_darts_space_constant_random_experiment_runner.py
index 0eb3b3ea..a5542679 100644
--- a/archai/algos/zero_cost_measures/zero_cost_darts_space_constant_random_experiment_runner.py
+++ b/archai/algos/zero_cost_measures/zero_cost_darts_space_constant_random_experiment_runner.py
@@ -18,6 +18,7 @@ from archai.nas.evaluater import EvalResult
 from archai.common.common import get_expdir, logger
 from archai.algos.proxynas.freeze_manual_searcher import ManualFreezeSearcher
 from archai.algos.zero_cost_measures.zero_cost_darts_space_constant_random_evaluator import ZeroCostDartsSpaceConstantRandomEvaluator
+from archai.algos.random_sample_darts_space.random_model_desc_builder import RandomModelDescBuilder
 
 from nats_bench import create
 
@@ -26,8 +27,8 @@ class ZeroCostDartsSpaceConstantRandomExperimentRunner(ExperimentRunner):
     which are randomly sampled in a reproducible way"""
 
     @overrides
-    def model_desc_builder(self)->Optional[ModelDescBuilder]:
-        return None
+    def model_desc_builder(self)->RandomModelDescBuilder:
+        return RandomModelDescBuilder()
 
     @overrides
     def trainer_class(self)->TArchTrainer:
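
Note on the repeated TODO above: the same isinstance check is copy-pasted into
every measure, and the natsbench sss path is left commented out, so the tuple
handling could instead be centralized in one place. Below is a minimal sketch
of such a helper, assuming (as the unpacking in this patch suggests) that the
darts-space Model returns a (logits, aux_logits) tuple and that natsbench sss
networks return an (activation, logits) tuple; the helper name unwrap_logits
is hypothetical and not part of the codebase.

import torch
import torch.nn as nn

from archai.nas.model import Model


def unwrap_logits(net: nn.Module, outputs) -> torch.Tensor:
    """Hypothetical helper: normalize net.forward() output to plain logits."""
    if isinstance(net, Model):
        # assumed: darts-space Model returns (logits, aux_logits);
        # keep the logits and drop the auxiliary-head output
        return outputs[0]
    # built-in tuple rather than typing.Tuple, so the isinstance check is exact
    if isinstance(outputs, tuple) and len(outputs) == 2:
        # natsbench sss produces an (activation, logits) tuple; keep the logits
        return outputs[1]
    # any other API is assumed to return logits directly
    return outputs

Each measure could then use, e.g.,
outputs = unwrap_logits(net, net.forward(inputs[st:en])), which would also
restore the natsbench handling that this patch comments out.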