зеркало из https://github.com/microsoft/archai.git
Zerocost proxies now work for darts architectures but are broken for natsbench architectures. Have to handle both properly.
This commit is contained in:
Родитель
0bb248e20d
Коммит
27e7d13fd1
|
@ -23,6 +23,8 @@ from typing import Tuple
|
|||
from . import measure
|
||||
from ..p_utils import get_layer_metric_array, reshape_elements
|
||||
|
||||
from archai.nas.model import Model
|
||||
|
||||
|
||||
def fisher_forward_conv2d(self, x):
|
||||
x = F.conv2d(x, self.weight, self.bias, self.stride,
|
||||
|
@ -86,10 +88,13 @@ def compute_fisher_per_weight(net, inputs, targets, loss_fn, mode, split_data=1)
|
|||
|
||||
net.zero_grad()
|
||||
outputs = net(inputs[st:en])
|
||||
# natsbench sss produces (activation, logits) tuple
|
||||
if isinstance(outputs, Tuple) and len(outputs) == 2:
|
||||
outputs = outputs[1]
|
||||
|
||||
# TODO: We have to deal with different output styles of
|
||||
# different APIs properly
|
||||
# # natsbench sss produces (activation, logits) tuple
|
||||
# if isinstance(outputs, Tuple) and len(outputs) == 2:
|
||||
# outputs = outputs[1]
|
||||
if isinstance(net, Model):
|
||||
outputs, aux_logits = outputs[0], outputs[1]
|
||||
loss = loss_fn(outputs, targets[st:en])
|
||||
loss.backward()
|
||||
|
||||
|
|
|
@ -20,6 +20,8 @@ from typing import Tuple
|
|||
|
||||
import copy
|
||||
|
||||
from archai.nas.model import Model
|
||||
|
||||
from . import measure
|
||||
from ..p_utils import get_layer_metric_array
|
||||
|
||||
|
@ -32,9 +34,13 @@ def get_grad_norm_arr(net, inputs, targets, loss_fn, split_data=1, skip_grad=Fal
|
|||
en=(sp+1)*N//split_data
|
||||
|
||||
outputs = net.forward(inputs[st:en])
|
||||
# natsbench sss produces (activation, logits) tuple
|
||||
if isinstance(outputs, Tuple) and len(outputs) == 2:
|
||||
outputs = outputs[1]
|
||||
# TODO: We have to deal with different output styles of
|
||||
# different APIs properly
|
||||
# # natsbench sss produces (activation, logits) tuple
|
||||
# if isinstance(outputs, Tuple) and len(outputs) == 2:
|
||||
# outputs = outputs[1]
|
||||
if isinstance(net, Model):
|
||||
outputs, aux_logits = outputs[0], outputs[1]
|
||||
loss = loss_fn(outputs, targets[st:en])
|
||||
loss.backward()
|
||||
|
||||
|
|
|
@ -23,6 +23,8 @@ from typing import Tuple
|
|||
from . import measure
|
||||
from ..p_utils import get_layer_metric_array
|
||||
|
||||
from archai.nas.model import Model
|
||||
|
||||
|
||||
@measure('grasp', bn=True, mode='param')
|
||||
def compute_grasp_per_weight(net, inputs, targets, mode, loss_fn, T=1, num_iters=1, split_data=1):
|
||||
|
@ -47,9 +49,15 @@ def compute_grasp_per_weight(net, inputs, targets, mode, loss_fn, T=1, num_iters
|
|||
for _ in range(num_iters):
|
||||
#TODO get new data, otherwise num_iters is useless!
|
||||
outputs = net.forward(inputs[st:en])
|
||||
# natsbench sss produces (activation, logits) tuple
|
||||
if isinstance(outputs, Tuple) and len(outputs) == 2:
|
||||
outputs = outputs[1]
|
||||
|
||||
# TODO: We have to deal with different output styles of
|
||||
# different APIs properly
|
||||
# # natsbench sss produces (activation, logits) tuple
|
||||
# if isinstance(outputs, Tuple) and len(outputs) == 2:
|
||||
# outputs = outputs[1]
|
||||
if isinstance(net, Model):
|
||||
outputs, aux_logits = outputs[0], outputs[1]
|
||||
|
||||
outputs = outputs/T
|
||||
|
||||
loss = loss_fn(outputs, targets[st:en])
|
||||
|
@ -67,9 +75,13 @@ def compute_grasp_per_weight(net, inputs, targets, mode, loss_fn, T=1, num_iters
|
|||
|
||||
# forward/grad pass #2
|
||||
outputs = net.forward(inputs[st:en])
|
||||
# natsbench sss produces (activation, logits) tuple
|
||||
if isinstance(outputs, Tuple) and len(outputs) == 2:
|
||||
outputs = outputs[1]
|
||||
# TODO: We have to deal with different output styles of
|
||||
# different APIs properly
|
||||
# # natsbench sss produces (activation, logits) tuple
|
||||
# if isinstance(outputs, Tuple) and len(outputs) == 2:
|
||||
# outputs = outputs[1]
|
||||
if isinstance(net, Model):
|
||||
outputs, aux_logits = outputs[0], outputs[1]
|
||||
outputs = outputs/T
|
||||
|
||||
loss = loss_fn(outputs, targets[st:en])
|
||||
|
|
|
@ -19,6 +19,7 @@ from typing import Tuple
|
|||
|
||||
from . import measure
|
||||
|
||||
from archai.nas.model import Model
|
||||
|
||||
def get_batch_jacobian(net, x, target, device, split_data):
|
||||
x.requires_grad_(True)
|
||||
|
@ -28,9 +29,15 @@ def get_batch_jacobian(net, x, target, device, split_data):
|
|||
st=sp*N//split_data
|
||||
en=(sp+1)*N//split_data
|
||||
y = net(x[st:en])
|
||||
# natsbench sss produces (activation, logits) tuple
|
||||
if isinstance(y, Tuple):
|
||||
y = y[1]
|
||||
|
||||
# TODO: We have to deal with different output styles of
|
||||
# different APIs properly
|
||||
# # natsbench sss produces (activation, logits) tuple
|
||||
# if isinstance(outputs, Tuple) and len(outputs) == 2:
|
||||
# outputs = outputs[1]
|
||||
if isinstance(net, Model):
|
||||
y, aux_logits = y[0], y[1]
|
||||
|
||||
y.backward(torch.ones_like(y))
|
||||
|
||||
jacob = x.grad.detach()
|
||||
|
|
|
@ -21,6 +21,8 @@ from typing import Tuple
|
|||
from . import measure
|
||||
from ..p_utils import get_layer_metric_array
|
||||
|
||||
from archai.nas.model import Model
|
||||
|
||||
|
||||
@measure('plain', bn=True, mode='param')
|
||||
def compute_plain_per_weight(net, inputs, targets, mode, loss_fn, split_data=1):
|
||||
|
@ -32,10 +34,13 @@ def compute_plain_per_weight(net, inputs, targets, mode, loss_fn, split_data=1):
|
|||
en=(sp+1)*N//split_data
|
||||
|
||||
outputs = net.forward(inputs[st:en])
|
||||
# natsbench sss produces (activation, logits) tuple
|
||||
if isinstance(outputs, Tuple) and len(outputs) == 2:
|
||||
outputs = outputs[1]
|
||||
|
||||
# TODO: We have to deal with different output styles of
|
||||
# different APIs properly
|
||||
# # natsbench sss produces (activation, logits) tuple
|
||||
# if isinstance(outputs, Tuple) and len(outputs) == 2:
|
||||
# outputs = outputs[1]
|
||||
if isinstance(net, Model):
|
||||
outputs, aux_logits = outputs[0], outputs[1]
|
||||
loss = loss_fn(outputs, targets[st:en])
|
||||
loss.backward()
|
||||
|
||||
|
|
|
@ -25,6 +25,8 @@ from typing import Tuple
|
|||
from . import measure
|
||||
from ..p_utils import get_layer_metric_array
|
||||
|
||||
from archai.nas.model import Model
|
||||
|
||||
|
||||
def snip_forward_conv2d(self, x):
|
||||
return F.conv2d(x, self.weight * self.weight_mask, self.bias,
|
||||
|
@ -55,9 +57,13 @@ def compute_snip_per_weight(net, inputs, targets, mode, loss_fn, split_data=1):
|
|||
en=(sp+1)*N//split_data
|
||||
|
||||
outputs = net.forward(inputs[st:en])
|
||||
# natsbench sss produces (activation, logits) tuple
|
||||
if isinstance(outputs, Tuple) and len(outputs) == 2:
|
||||
outputs = outputs[1]
|
||||
# TODO: We have to deal with different output styles of
|
||||
# different APIs properly
|
||||
# # natsbench sss produces (activation, logits) tuple
|
||||
# if isinstance(outputs, Tuple) and len(outputs) == 2:
|
||||
# outputs = outputs[1]
|
||||
if isinstance(net, Model):
|
||||
outputs, aux_logits = outputs[0], outputs[1]
|
||||
loss = loss_fn(outputs, targets[st:en])
|
||||
loss.backward()
|
||||
|
||||
|
|
|
@ -20,6 +20,8 @@ from typing import Tuple
|
|||
from . import measure
|
||||
from ..p_utils import get_layer_metric_array
|
||||
|
||||
from archai.nas.model import Model
|
||||
|
||||
|
||||
@measure('synflow', bn=False, mode='param')
|
||||
@measure('synflow_bn', bn=True, mode='param')
|
||||
|
@ -53,10 +55,14 @@ def compute_synflow_per_weight(net, inputs, targets, mode, split_data=1, loss_fn
|
|||
inputs = torch.ones([1] + input_dim).double().to(device)
|
||||
output = net.forward(inputs)
|
||||
|
||||
# natsbench sss produces (activation, logits) tuple
|
||||
if isinstance(output, Tuple) and len(output) == 2:
|
||||
output = output[1]
|
||||
|
||||
# TODO: We have to deal with different output styles of
|
||||
# different APIs properly
|
||||
# # natsbench sss produces (activation, logits) tuple
|
||||
# if isinstance(outputs, Tuple) and len(outputs) == 2:
|
||||
# outputs = outputs[1]
|
||||
if isinstance(net, Model):
|
||||
output, aux_logits = output[0], output[1]
|
||||
|
||||
torch.sum(output).backward()
|
||||
|
||||
# select the gradients that we want to use for search/prune
|
||||
|
|
|
@ -37,8 +37,8 @@ class ZeroCostDartsSpaceConstantRandomEvaluator(Evaluater):
|
|||
def create_model(self, conf_eval:Config, model_desc_builder:RandomModelDescBuilder,
|
||||
final_desc_filename=None, full_desc_filename=None)->nn.Module:
|
||||
|
||||
assert model_desc_builder is not None, 'DartsSpaceEvaluater requires model_desc_builder'
|
||||
assert final_desc_filename is None, 'DartsSpaceEvaluater creates its own model desc based on arch index'
|
||||
assert model_desc_builder is not None, 'ZeroCostDartsSpaceConstantRandomEvaluator requires model_desc_builder'
|
||||
assert final_desc_filename is None, 'ZeroCostDartsSpaceConstantRandomEvaluator creates its own model desc based on arch index'
|
||||
assert type(model_desc_builder) == RandomModelDescBuilder, 'DartsSpaceEvaluater requires RandomModelDescBuilder'
|
||||
|
||||
# region conf vars
|
||||
|
@ -48,6 +48,7 @@ class ZeroCostDartsSpaceConstantRandomEvaluator(Evaluater):
|
|||
full_desc_filename = conf_eval['full_desc_filename']
|
||||
conf_model_desc = conf_eval['model_desc']
|
||||
arch_index = conf_eval['dartsspace']['arch_index']
|
||||
self.num_classes = conf_eval['loader']['dataset']['n_classes']
|
||||
# endregion
|
||||
|
||||
assert arch_index >= 0
|
||||
|
|
|
@ -18,6 +18,7 @@ from archai.nas.evaluater import EvalResult
|
|||
from archai.common.common import get_expdir, logger
|
||||
from archai.algos.proxynas.freeze_manual_searcher import ManualFreezeSearcher
|
||||
from archai.algos.zero_cost_measures.zero_cost_darts_space_constant_random_evaluator import ZeroCostDartsSpaceConstantRandomEvaluator
|
||||
from archai.algos.random_sample_darts_space.random_model_desc_builder import RandomModelDescBuilder
|
||||
|
||||
from nats_bench import create
|
||||
|
||||
|
@ -26,8 +27,8 @@ class ZeroCostDartsSpaceConstantRandomExperimentRunner(ExperimentRunner):
|
|||
which are randomly sampled in a reproducible way"""
|
||||
|
||||
@overrides
|
||||
def model_desc_builder(self)->Optional[ModelDescBuilder]:
|
||||
return None
|
||||
def model_desc_builder(self)->RandomModelDescBuilder:
|
||||
return RandomModelDescBuilder()
|
||||
|
||||
@overrides
|
||||
def trainer_class(self)->TArchTrainer:
|
||||
|
|
Загрузка…
Ссылка в новой задаче