Made zerocost measures code resilient to darts, natsbench tss, and sss outputs.

This commit is contained in:
Debadeepta Dey 2021-12-03 17:34:28 -08:00 committed by Gustavo Rosa
Parent 27e7d13fd1
Commit 586dde98a8
8 changed files with 32 additions and 37 deletions

2
.vscode/launch.json Vendored
View File

@@ -350,7 +350,7 @@
"request": "launch",
"program": "${cwd}/scripts/main.py",
"console": "integratedTerminal",
"args": ["--full", "--algos", "zerocost_natsbench_space", "--datasets", "synthetic_cifar10"]
"args": ["--full", "--algos", "zerocost_natsbench_space", "--datasets", "cifar10"]
},
{
"name": "ZeroCost-Natsbench-Space-Toy",

View File

@@ -25,6 +25,8 @@ from ..p_utils import get_layer_metric_array, reshape_elements
from archai.nas.model import Model
from archai.algos.natsbench.lib.models.shape_infers.InferTinyCellNet import DynamicShapeTinyNet
def fisher_forward_conv2d(self, x):
x = F.conv2d(x, self.weight, self.bias, self.stride,
@@ -88,11 +90,9 @@ def compute_fisher_per_weight(net, inputs, targets, loss_fn, mode, split_data=1)
net.zero_grad()
outputs = net(inputs[st:en])
# TODO: We have to deal with different output styles of
# different APIs properly
# # natsbench sss produces (activation, logits) tuple
# if isinstance(outputs, Tuple) and len(outputs) == 2:
# outputs = outputs[1]
# natsbench sss produces (activation, logits) tuple
if isinstance(outputs, DynamicShapeTinyNet) and len(outputs) == 2:
outputs = outputs[1]
if isinstance(net, Model):
outputs, aux_logits = outputs[0], outputs[1]
loss = loss_fn(outputs, targets[st:en])

View File

@@ -21,6 +21,7 @@ from typing import Tuple
import copy
from archai.nas.model import Model
from archai.algos.natsbench.lib.models.shape_infers.InferTinyCellNet import DynamicShapeTinyNet
from . import measure
from ..p_utils import get_layer_metric_array
@@ -34,11 +35,9 @@ def get_grad_norm_arr(net, inputs, targets, loss_fn, split_data=1, skip_grad=Fal
en=(sp+1)*N//split_data
outputs = net.forward(inputs[st:en])
# TODO: We have to deal with different output styles of
# different APIs properly
# # natsbench sss produces (activation, logits) tuple
# if isinstance(outputs, Tuple) and len(outputs) == 2:
# outputs = outputs[1]
if isinstance(outputs, DynamicShapeTinyNet) and len(outputs) == 2:
outputs = outputs[1]
if isinstance(net, Model):
outputs, aux_logits = outputs[0], outputs[1]
loss = loss_fn(outputs, targets[st:en])

View File

@@ -25,6 +25,8 @@ from ..p_utils import get_layer_metric_array
from archai.nas.model import Model
from archai.algos.natsbench.lib.models.shape_infers.InferTinyCellNet import DynamicShapeTinyNet
@measure('grasp', bn=True, mode='param')
def compute_grasp_per_weight(net, inputs, targets, mode, loss_fn, T=1, num_iters=1, split_data=1):
@@ -50,11 +52,9 @@ def compute_grasp_per_weight(net, inputs, targets, mode, loss_fn, T=1, num_iters
#TODO get new data, otherwise num_iters is useless!
outputs = net.forward(inputs[st:en])
# TODO: We have to deal with different output styles of
# different APIs properly
# # natsbench sss produces (activation, logits) tuple
# if isinstance(outputs, Tuple) and len(outputs) == 2:
# outputs = outputs[1]
# natsbench sss produces (activation, logits) tuple
if isinstance(outputs, DynamicShapeTinyNet) and len(outputs) == 2:
outputs = outputs[1]
if isinstance(net, Model):
outputs, aux_logits = outputs[0], outputs[1]

View File

@@ -20,6 +20,7 @@ from typing import Tuple
from . import measure
from archai.nas.model import Model
from archai.algos.natsbench.lib.models.shape_infers.InferTinyCellNet import DynamicShapeTinyNet
def get_batch_jacobian(net, x, target, device, split_data):
x.requires_grad_(True)
@@ -30,11 +31,9 @@ def get_batch_jacobian(net, x, target, device, split_data):
en=(sp+1)*N//split_data
y = net(x[st:en])
# TODO: We have to deal with different output styles of
# different APIs properly
# # natsbench sss produces (activation, logits) tuple
# if isinstance(outputs, Tuple) and len(outputs) == 2:
# outputs = outputs[1]
# natsbench sss produces (activation, logits) tuple
if isinstance(y, DynamicShapeTinyNet) and len(y) == 2:
y = y[1]
if isinstance(net, Model):
y, aux_logits = y[0], y[1]

View File

@@ -22,6 +22,7 @@ from . import measure
from ..p_utils import get_layer_metric_array
from archai.nas.model import Model
from archai.algos.natsbench.lib.models.shape_infers.InferTinyCellNet import DynamicShapeTinyNet
@measure('plain', bn=True, mode='param')
@@ -34,11 +35,10 @@ def compute_plain_per_weight(net, inputs, targets, mode, loss_fn, split_data=1):
en=(sp+1)*N//split_data
outputs = net.forward(inputs[st:en])
# TODO: We have to deal with different output styles of
# different APIs properly
# # natsbench sss produces (activation, logits) tuple
# if isinstance(outputs, Tuple) and len(outputs) == 2:
# outputs = outputs[1]
# natsbench sss produces (activation, logits) tuple
if isinstance(outputs, DynamicShapeTinyNet) and len(outputs) == 2:
outputs = outputs[1]
if isinstance(net, Model):
outputs, aux_logits = outputs[0], outputs[1]
loss = loss_fn(outputs, targets[st:en])

View File

@@ -26,7 +26,7 @@ from . import measure
from ..p_utils import get_layer_metric_array
from archai.nas.model import Model
from archai.algos.natsbench.lib.models.shape_infers.InferTinyCellNet import DynamicShapeTinyNet
def snip_forward_conv2d(self, x):
return F.conv2d(x, self.weight * self.weight_mask, self.bias,
@@ -57,11 +57,10 @@ def compute_snip_per_weight(net, inputs, targets, mode, loss_fn, split_data=1):
en=(sp+1)*N//split_data
outputs = net.forward(inputs[st:en])
# TODO: We have to deal with different output styles of
# different APIs properly
# # natsbench sss produces (activation, logits) tuple
# if isinstance(outputs, Tuple) and len(outputs) == 2:
# outputs = outputs[1]
# natsbench sss produces (activation, logits) tuple
if isinstance(outputs, DynamicShapeTinyNet) and len(outputs) == 2:
outputs = outputs[1]
if isinstance(net, Model):
outputs, aux_logits = outputs[0], outputs[1]
loss = loss_fn(outputs, targets[st:en])

View File

@@ -21,7 +21,7 @@ from . import measure
from ..p_utils import get_layer_metric_array
from archai.nas.model import Model
from archai.algos.natsbench.lib.models.shape_infers.InferTinyCellNet import DynamicShapeTinyNet
@measure('synflow', bn=False, mode='param')
@measure('synflow_bn', bn=True, mode='param')
@@ -55,11 +55,9 @@ def compute_synflow_per_weight(net, inputs, targets, mode, split_data=1, loss_fn
inputs = torch.ones([1] + input_dim).double().to(device)
output = net.forward(inputs)
# TODO: We have to deal with different output styles of
# different APIs properly
# # natsbench sss produces (activation, logits) tuple
# if isinstance(outputs, Tuple) and len(outputs) == 2:
# outputs = outputs[1]
# natsbench sss produces (activation, logits) tuple
if isinstance(output, DynamicShapeTinyNet) and len(output) == 2:
output = output[1]
if isinstance(net, Model):
output, aux_logits = output[0], output[1]