зеркало из https://github.com/microsoft/archai.git
Zerocost proxies now work for darts architectures but are broken for natsbench architectures. Have to handle both properly.
This commit is contained in:
Родитель
0bb248e20d
Коммит
27e7d13fd1
|
@ -23,6 +23,8 @@ from typing import Tuple
|
||||||
from . import measure
|
from . import measure
|
||||||
from ..p_utils import get_layer_metric_array, reshape_elements
|
from ..p_utils import get_layer_metric_array, reshape_elements
|
||||||
|
|
||||||
|
from archai.nas.model import Model
|
||||||
|
|
||||||
|
|
||||||
def fisher_forward_conv2d(self, x):
|
def fisher_forward_conv2d(self, x):
|
||||||
x = F.conv2d(x, self.weight, self.bias, self.stride,
|
x = F.conv2d(x, self.weight, self.bias, self.stride,
|
||||||
|
@ -86,10 +88,13 @@ def compute_fisher_per_weight(net, inputs, targets, loss_fn, mode, split_data=1)
|
||||||
|
|
||||||
net.zero_grad()
|
net.zero_grad()
|
||||||
outputs = net(inputs[st:en])
|
outputs = net(inputs[st:en])
|
||||||
# natsbench sss produces (activation, logits) tuple
|
# TODO: We have to deal with different output styles of
|
||||||
if isinstance(outputs, Tuple) and len(outputs) == 2:
|
# different APIs properly
|
||||||
outputs = outputs[1]
|
# # natsbench sss produces (activation, logits) tuple
|
||||||
|
# if isinstance(outputs, Tuple) and len(outputs) == 2:
|
||||||
|
# outputs = outputs[1]
|
||||||
|
if isinstance(net, Model):
|
||||||
|
outputs, aux_logits = outputs[0], outputs[1]
|
||||||
loss = loss_fn(outputs, targets[st:en])
|
loss = loss_fn(outputs, targets[st:en])
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,8 @@ from typing import Tuple
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
|
|
||||||
|
from archai.nas.model import Model
|
||||||
|
|
||||||
from . import measure
|
from . import measure
|
||||||
from ..p_utils import get_layer_metric_array
|
from ..p_utils import get_layer_metric_array
|
||||||
|
|
||||||
|
@ -32,9 +34,13 @@ def get_grad_norm_arr(net, inputs, targets, loss_fn, split_data=1, skip_grad=Fal
|
||||||
en=(sp+1)*N//split_data
|
en=(sp+1)*N//split_data
|
||||||
|
|
||||||
outputs = net.forward(inputs[st:en])
|
outputs = net.forward(inputs[st:en])
|
||||||
# natsbench sss produces (activation, logits) tuple
|
# TODO: We have to deal with different output styles of
|
||||||
if isinstance(outputs, Tuple) and len(outputs) == 2:
|
# different APIs properly
|
||||||
outputs = outputs[1]
|
# # natsbench sss produces (activation, logits) tuple
|
||||||
|
# if isinstance(outputs, Tuple) and len(outputs) == 2:
|
||||||
|
# outputs = outputs[1]
|
||||||
|
if isinstance(net, Model):
|
||||||
|
outputs, aux_logits = outputs[0], outputs[1]
|
||||||
loss = loss_fn(outputs, targets[st:en])
|
loss = loss_fn(outputs, targets[st:en])
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
|
|
|
@ -23,6 +23,8 @@ from typing import Tuple
|
||||||
from . import measure
|
from . import measure
|
||||||
from ..p_utils import get_layer_metric_array
|
from ..p_utils import get_layer_metric_array
|
||||||
|
|
||||||
|
from archai.nas.model import Model
|
||||||
|
|
||||||
|
|
||||||
@measure('grasp', bn=True, mode='param')
|
@measure('grasp', bn=True, mode='param')
|
||||||
def compute_grasp_per_weight(net, inputs, targets, mode, loss_fn, T=1, num_iters=1, split_data=1):
|
def compute_grasp_per_weight(net, inputs, targets, mode, loss_fn, T=1, num_iters=1, split_data=1):
|
||||||
|
@ -47,9 +49,15 @@ def compute_grasp_per_weight(net, inputs, targets, mode, loss_fn, T=1, num_iters
|
||||||
for _ in range(num_iters):
|
for _ in range(num_iters):
|
||||||
#TODO get new data, otherwise num_iters is useless!
|
#TODO get new data, otherwise num_iters is useless!
|
||||||
outputs = net.forward(inputs[st:en])
|
outputs = net.forward(inputs[st:en])
|
||||||
# natsbench sss produces (activation, logits) tuple
|
|
||||||
if isinstance(outputs, Tuple) and len(outputs) == 2:
|
# TODO: We have to deal with different output styles of
|
||||||
outputs = outputs[1]
|
# different APIs properly
|
||||||
|
# # natsbench sss produces (activation, logits) tuple
|
||||||
|
# if isinstance(outputs, Tuple) and len(outputs) == 2:
|
||||||
|
# outputs = outputs[1]
|
||||||
|
if isinstance(net, Model):
|
||||||
|
outputs, aux_logits = outputs[0], outputs[1]
|
||||||
|
|
||||||
outputs = outputs/T
|
outputs = outputs/T
|
||||||
|
|
||||||
loss = loss_fn(outputs, targets[st:en])
|
loss = loss_fn(outputs, targets[st:en])
|
||||||
|
@ -67,9 +75,13 @@ def compute_grasp_per_weight(net, inputs, targets, mode, loss_fn, T=1, num_iters
|
||||||
|
|
||||||
# forward/grad pass #2
|
# forward/grad pass #2
|
||||||
outputs = net.forward(inputs[st:en])
|
outputs = net.forward(inputs[st:en])
|
||||||
# natsbench sss produces (activation, logits) tuple
|
# TODO: We have to deal with different output styles of
|
||||||
if isinstance(outputs, Tuple) and len(outputs) == 2:
|
# different APIs properly
|
||||||
outputs = outputs[1]
|
# # natsbench sss produces (activation, logits) tuple
|
||||||
|
# if isinstance(outputs, Tuple) and len(outputs) == 2:
|
||||||
|
# outputs = outputs[1]
|
||||||
|
if isinstance(net, Model):
|
||||||
|
outputs, aux_logits = outputs[0], outputs[1]
|
||||||
outputs = outputs/T
|
outputs = outputs/T
|
||||||
|
|
||||||
loss = loss_fn(outputs, targets[st:en])
|
loss = loss_fn(outputs, targets[st:en])
|
||||||
|
|
|
@ -19,6 +19,7 @@ from typing import Tuple
|
||||||
|
|
||||||
from . import measure
|
from . import measure
|
||||||
|
|
||||||
|
from archai.nas.model import Model
|
||||||
|
|
||||||
def get_batch_jacobian(net, x, target, device, split_data):
|
def get_batch_jacobian(net, x, target, device, split_data):
|
||||||
x.requires_grad_(True)
|
x.requires_grad_(True)
|
||||||
|
@ -28,9 +29,15 @@ def get_batch_jacobian(net, x, target, device, split_data):
|
||||||
st=sp*N//split_data
|
st=sp*N//split_data
|
||||||
en=(sp+1)*N//split_data
|
en=(sp+1)*N//split_data
|
||||||
y = net(x[st:en])
|
y = net(x[st:en])
|
||||||
# natsbench sss produces (activation, logits) tuple
|
|
||||||
if isinstance(y, Tuple):
|
# TODO: We have to deal with different output styles of
|
||||||
y = y[1]
|
# different APIs properly
|
||||||
|
# # natsbench sss produces (activation, logits) tuple
|
||||||
|
# if isinstance(outputs, Tuple) and len(outputs) == 2:
|
||||||
|
# outputs = outputs[1]
|
||||||
|
if isinstance(net, Model):
|
||||||
|
y, aux_logits = y[0], y[1]
|
||||||
|
|
||||||
y.backward(torch.ones_like(y))
|
y.backward(torch.ones_like(y))
|
||||||
|
|
||||||
jacob = x.grad.detach()
|
jacob = x.grad.detach()
|
||||||
|
|
|
@ -21,6 +21,8 @@ from typing import Tuple
|
||||||
from . import measure
|
from . import measure
|
||||||
from ..p_utils import get_layer_metric_array
|
from ..p_utils import get_layer_metric_array
|
||||||
|
|
||||||
|
from archai.nas.model import Model
|
||||||
|
|
||||||
|
|
||||||
@measure('plain', bn=True, mode='param')
|
@measure('plain', bn=True, mode='param')
|
||||||
def compute_plain_per_weight(net, inputs, targets, mode, loss_fn, split_data=1):
|
def compute_plain_per_weight(net, inputs, targets, mode, loss_fn, split_data=1):
|
||||||
|
@ -32,10 +34,13 @@ def compute_plain_per_weight(net, inputs, targets, mode, loss_fn, split_data=1):
|
||||||
en=(sp+1)*N//split_data
|
en=(sp+1)*N//split_data
|
||||||
|
|
||||||
outputs = net.forward(inputs[st:en])
|
outputs = net.forward(inputs[st:en])
|
||||||
# natsbench sss produces (activation, logits) tuple
|
# TODO: We have to deal with different output styles of
|
||||||
if isinstance(outputs, Tuple) and len(outputs) == 2:
|
# different APIs properly
|
||||||
outputs = outputs[1]
|
# # natsbench sss produces (activation, logits) tuple
|
||||||
|
# if isinstance(outputs, Tuple) and len(outputs) == 2:
|
||||||
|
# outputs = outputs[1]
|
||||||
|
if isinstance(net, Model):
|
||||||
|
outputs, aux_logits = outputs[0], outputs[1]
|
||||||
loss = loss_fn(outputs, targets[st:en])
|
loss = loss_fn(outputs, targets[st:en])
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
|
|
|
@ -25,6 +25,8 @@ from typing import Tuple
|
||||||
from . import measure
|
from . import measure
|
||||||
from ..p_utils import get_layer_metric_array
|
from ..p_utils import get_layer_metric_array
|
||||||
|
|
||||||
|
from archai.nas.model import Model
|
||||||
|
|
||||||
|
|
||||||
def snip_forward_conv2d(self, x):
|
def snip_forward_conv2d(self, x):
|
||||||
return F.conv2d(x, self.weight * self.weight_mask, self.bias,
|
return F.conv2d(x, self.weight * self.weight_mask, self.bias,
|
||||||
|
@ -55,9 +57,13 @@ def compute_snip_per_weight(net, inputs, targets, mode, loss_fn, split_data=1):
|
||||||
en=(sp+1)*N//split_data
|
en=(sp+1)*N//split_data
|
||||||
|
|
||||||
outputs = net.forward(inputs[st:en])
|
outputs = net.forward(inputs[st:en])
|
||||||
# natsbench sss produces (activation, logits) tuple
|
# TODO: We have to deal with different output styles of
|
||||||
if isinstance(outputs, Tuple) and len(outputs) == 2:
|
# different APIs properly
|
||||||
outputs = outputs[1]
|
# # natsbench sss produces (activation, logits) tuple
|
||||||
|
# if isinstance(outputs, Tuple) and len(outputs) == 2:
|
||||||
|
# outputs = outputs[1]
|
||||||
|
if isinstance(net, Model):
|
||||||
|
outputs, aux_logits = outputs[0], outputs[1]
|
||||||
loss = loss_fn(outputs, targets[st:en])
|
loss = loss_fn(outputs, targets[st:en])
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,8 @@ from typing import Tuple
|
||||||
from . import measure
|
from . import measure
|
||||||
from ..p_utils import get_layer_metric_array
|
from ..p_utils import get_layer_metric_array
|
||||||
|
|
||||||
|
from archai.nas.model import Model
|
||||||
|
|
||||||
|
|
||||||
@measure('synflow', bn=False, mode='param')
|
@measure('synflow', bn=False, mode='param')
|
||||||
@measure('synflow_bn', bn=True, mode='param')
|
@measure('synflow_bn', bn=True, mode='param')
|
||||||
|
@ -53,10 +55,14 @@ def compute_synflow_per_weight(net, inputs, targets, mode, split_data=1, loss_fn
|
||||||
inputs = torch.ones([1] + input_dim).double().to(device)
|
inputs = torch.ones([1] + input_dim).double().to(device)
|
||||||
output = net.forward(inputs)
|
output = net.forward(inputs)
|
||||||
|
|
||||||
# natsbench sss produces (activation, logits) tuple
|
# TODO: We have to deal with different output styles of
|
||||||
if isinstance(output, Tuple) and len(output) == 2:
|
# different APIs properly
|
||||||
output = output[1]
|
# # natsbench sss produces (activation, logits) tuple
|
||||||
|
# if isinstance(outputs, Tuple) and len(outputs) == 2:
|
||||||
|
# outputs = outputs[1]
|
||||||
|
if isinstance(net, Model):
|
||||||
|
output, aux_logits = output[0], output[1]
|
||||||
|
|
||||||
torch.sum(output).backward()
|
torch.sum(output).backward()
|
||||||
|
|
||||||
# select the gradients that we want to use for search/prune
|
# select the gradients that we want to use for search/prune
|
||||||
|
|
|
@ -37,8 +37,8 @@ class ZeroCostDartsSpaceConstantRandomEvaluator(Evaluater):
|
||||||
def create_model(self, conf_eval:Config, model_desc_builder:RandomModelDescBuilder,
|
def create_model(self, conf_eval:Config, model_desc_builder:RandomModelDescBuilder,
|
||||||
final_desc_filename=None, full_desc_filename=None)->nn.Module:
|
final_desc_filename=None, full_desc_filename=None)->nn.Module:
|
||||||
|
|
||||||
assert model_desc_builder is not None, 'DartsSpaceEvaluater requires model_desc_builder'
|
assert model_desc_builder is not None, 'ZeroCostDartsSpaceConstantRandomEvaluator requires model_desc_builder'
|
||||||
assert final_desc_filename is None, 'DartsSpaceEvaluater creates its own model desc based on arch index'
|
assert final_desc_filename is None, 'ZeroCostDartsSpaceConstantRandomEvaluator creates its own model desc based on arch index'
|
||||||
assert type(model_desc_builder) == RandomModelDescBuilder, 'DartsSpaceEvaluater requires RandomModelDescBuilder'
|
assert type(model_desc_builder) == RandomModelDescBuilder, 'DartsSpaceEvaluater requires RandomModelDescBuilder'
|
||||||
|
|
||||||
# region conf vars
|
# region conf vars
|
||||||
|
@ -48,6 +48,7 @@ class ZeroCostDartsSpaceConstantRandomEvaluator(Evaluater):
|
||||||
full_desc_filename = conf_eval['full_desc_filename']
|
full_desc_filename = conf_eval['full_desc_filename']
|
||||||
conf_model_desc = conf_eval['model_desc']
|
conf_model_desc = conf_eval['model_desc']
|
||||||
arch_index = conf_eval['dartsspace']['arch_index']
|
arch_index = conf_eval['dartsspace']['arch_index']
|
||||||
|
self.num_classes = conf_eval['loader']['dataset']['n_classes']
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
assert arch_index >= 0
|
assert arch_index >= 0
|
||||||
|
|
|
@ -18,6 +18,7 @@ from archai.nas.evaluater import EvalResult
|
||||||
from archai.common.common import get_expdir, logger
|
from archai.common.common import get_expdir, logger
|
||||||
from archai.algos.proxynas.freeze_manual_searcher import ManualFreezeSearcher
|
from archai.algos.proxynas.freeze_manual_searcher import ManualFreezeSearcher
|
||||||
from archai.algos.zero_cost_measures.zero_cost_darts_space_constant_random_evaluator import ZeroCostDartsSpaceConstantRandomEvaluator
|
from archai.algos.zero_cost_measures.zero_cost_darts_space_constant_random_evaluator import ZeroCostDartsSpaceConstantRandomEvaluator
|
||||||
|
from archai.algos.random_sample_darts_space.random_model_desc_builder import RandomModelDescBuilder
|
||||||
|
|
||||||
from nats_bench import create
|
from nats_bench import create
|
||||||
|
|
||||||
|
@ -26,8 +27,8 @@ class ZeroCostDartsSpaceConstantRandomExperimentRunner(ExperimentRunner):
|
||||||
which are randomly sampled in a reproducible way"""
|
which are randomly sampled in a reproducible way"""
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def model_desc_builder(self)->Optional[ModelDescBuilder]:
|
def model_desc_builder(self)->RandomModelDescBuilder:
|
||||||
return None
|
return RandomModelDescBuilder()
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def trainer_class(self)->TArchTrainer:
|
def trainer_class(self)->TArchTrainer:
|
||||||
|
|
Загрузка…
Ссылка в новой задаче