зеркало из https://github.com/microsoft/archai.git
Made zerocost measures code resilient to darts, natsbench tss and sss outputs.
This commit is contained in:
Родитель
27e7d13fd1
Коммит
586dde98a8
|
@ -350,7 +350,7 @@
|
|||
"request": "launch",
|
||||
"program": "${cwd}/scripts/main.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": ["--full", "--algos", "zerocost_natsbench_space", "--datasets", "synthetic_cifar10"]
|
||||
"args": ["--full", "--algos", "zerocost_natsbench_space", "--datasets", "cifar10"]
|
||||
},
|
||||
{
|
||||
"name": "ZeroCost-Natsbench-Space-Toy",
|
||||
|
|
|
@ -25,6 +25,8 @@ from ..p_utils import get_layer_metric_array, reshape_elements
|
|||
|
||||
from archai.nas.model import Model
|
||||
|
||||
from archai.algos.natsbench.lib.models.shape_infers.InferTinyCellNet import DynamicShapeTinyNet
|
||||
|
||||
|
||||
def fisher_forward_conv2d(self, x):
|
||||
x = F.conv2d(x, self.weight, self.bias, self.stride,
|
||||
|
@ -88,11 +90,9 @@ def compute_fisher_per_weight(net, inputs, targets, loss_fn, mode, split_data=1)
|
|||
|
||||
net.zero_grad()
|
||||
outputs = net(inputs[st:en])
|
||||
# TODO: We have to deal with different output styles of
|
||||
# different APIs properly
|
||||
# # natsbench sss produces (activation, logits) tuple
|
||||
# if isinstance(outputs, Tuple) and len(outputs) == 2:
|
||||
# outputs = outputs[1]
|
||||
# natsbench sss produces (activation, logits) tuple
|
||||
if isinstance(outputs, DynamicShapeTinyNet) and len(outputs) == 2:
|
||||
outputs = outputs[1]
|
||||
if isinstance(net, Model):
|
||||
outputs, aux_logits = outputs[0], outputs[1]
|
||||
loss = loss_fn(outputs, targets[st:en])
|
||||
|
|
|
@ -21,6 +21,7 @@ from typing import Tuple
|
|||
import copy
|
||||
|
||||
from archai.nas.model import Model
|
||||
from archai.algos.natsbench.lib.models.shape_infers.InferTinyCellNet import DynamicShapeTinyNet
|
||||
|
||||
from . import measure
|
||||
from ..p_utils import get_layer_metric_array
|
||||
|
@ -34,11 +35,9 @@ def get_grad_norm_arr(net, inputs, targets, loss_fn, split_data=1, skip_grad=Fal
|
|||
en=(sp+1)*N//split_data
|
||||
|
||||
outputs = net.forward(inputs[st:en])
|
||||
# TODO: We have to deal with different output styles of
|
||||
# different APIs properly
|
||||
# # natsbench sss produces (activation, logits) tuple
|
||||
# if isinstance(outputs, Tuple) and len(outputs) == 2:
|
||||
# outputs = outputs[1]
|
||||
if isinstance(outputs, DynamicShapeTinyNet) and len(outputs) == 2:
|
||||
outputs = outputs[1]
|
||||
if isinstance(net, Model):
|
||||
outputs, aux_logits = outputs[0], outputs[1]
|
||||
loss = loss_fn(outputs, targets[st:en])
|
||||
|
|
|
@ -25,6 +25,8 @@ from ..p_utils import get_layer_metric_array
|
|||
|
||||
from archai.nas.model import Model
|
||||
|
||||
from archai.algos.natsbench.lib.models.shape_infers.InferTinyCellNet import DynamicShapeTinyNet
|
||||
|
||||
|
||||
@measure('grasp', bn=True, mode='param')
|
||||
def compute_grasp_per_weight(net, inputs, targets, mode, loss_fn, T=1, num_iters=1, split_data=1):
|
||||
|
@ -50,11 +52,9 @@ def compute_grasp_per_weight(net, inputs, targets, mode, loss_fn, T=1, num_iters
|
|||
#TODO get new data, otherwise num_iters is useless!
|
||||
outputs = net.forward(inputs[st:en])
|
||||
|
||||
# TODO: We have to deal with different output styles of
|
||||
# different APIs properly
|
||||
# # natsbench sss produces (activation, logits) tuple
|
||||
# if isinstance(outputs, Tuple) and len(outputs) == 2:
|
||||
# outputs = outputs[1]
|
||||
# natsbench sss produces (activation, logits) tuple
|
||||
if isinstance(outputs, DynamicShapeTinyNet) and len(outputs) == 2:
|
||||
outputs = outputs[1]
|
||||
if isinstance(net, Model):
|
||||
outputs, aux_logits = outputs[0], outputs[1]
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ from typing import Tuple
|
|||
from . import measure
|
||||
|
||||
from archai.nas.model import Model
|
||||
from archai.algos.natsbench.lib.models.shape_infers.InferTinyCellNet import DynamicShapeTinyNet
|
||||
|
||||
def get_batch_jacobian(net, x, target, device, split_data):
|
||||
x.requires_grad_(True)
|
||||
|
@ -30,11 +31,9 @@ def get_batch_jacobian(net, x, target, device, split_data):
|
|||
en=(sp+1)*N//split_data
|
||||
y = net(x[st:en])
|
||||
|
||||
# TODO: We have to deal with different output styles of
|
||||
# different APIs properly
|
||||
# # natsbench sss produces (activation, logits) tuple
|
||||
# if isinstance(outputs, Tuple) and len(outputs) == 2:
|
||||
# outputs = outputs[1]
|
||||
# natsbench sss produces (activation, logits) tuple
|
||||
if isinstance(y, DynamicShapeTinyNet) and len(y) == 2:
|
||||
y = y[1]
|
||||
if isinstance(net, Model):
|
||||
y, aux_logits = y[0], y[1]
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ from . import measure
|
|||
from ..p_utils import get_layer_metric_array
|
||||
|
||||
from archai.nas.model import Model
|
||||
from archai.algos.natsbench.lib.models.shape_infers.InferTinyCellNet import DynamicShapeTinyNet
|
||||
|
||||
|
||||
@measure('plain', bn=True, mode='param')
|
||||
|
@ -34,11 +35,10 @@ def compute_plain_per_weight(net, inputs, targets, mode, loss_fn, split_data=1):
|
|||
en=(sp+1)*N//split_data
|
||||
|
||||
outputs = net.forward(inputs[st:en])
|
||||
# TODO: We have to deal with different output styles of
|
||||
# different APIs properly
|
||||
# # natsbench sss produces (activation, logits) tuple
|
||||
# if isinstance(outputs, Tuple) and len(outputs) == 2:
|
||||
# outputs = outputs[1]
|
||||
|
||||
# natsbench sss produces (activation, logits) tuple
|
||||
if isinstance(outputs, DynamicShapeTinyNet) and len(outputs) == 2:
|
||||
outputs = outputs[1]
|
||||
if isinstance(net, Model):
|
||||
outputs, aux_logits = outputs[0], outputs[1]
|
||||
loss = loss_fn(outputs, targets[st:en])
|
||||
|
|
|
@ -26,7 +26,7 @@ from . import measure
|
|||
from ..p_utils import get_layer_metric_array
|
||||
|
||||
from archai.nas.model import Model
|
||||
|
||||
from archai.algos.natsbench.lib.models.shape_infers.InferTinyCellNet import DynamicShapeTinyNet
|
||||
|
||||
def snip_forward_conv2d(self, x):
|
||||
return F.conv2d(x, self.weight * self.weight_mask, self.bias,
|
||||
|
@ -57,11 +57,10 @@ def compute_snip_per_weight(net, inputs, targets, mode, loss_fn, split_data=1):
|
|||
en=(sp+1)*N//split_data
|
||||
|
||||
outputs = net.forward(inputs[st:en])
|
||||
# TODO: We have to deal with different output styles of
|
||||
# different APIs properly
|
||||
# # natsbench sss produces (activation, logits) tuple
|
||||
# if isinstance(outputs, Tuple) and len(outputs) == 2:
|
||||
# outputs = outputs[1]
|
||||
|
||||
# natsbench sss produces (activation, logits) tuple
|
||||
if isinstance(outputs, DynamicShapeTinyNet) and len(outputs) == 2:
|
||||
outputs = outputs[1]
|
||||
if isinstance(net, Model):
|
||||
outputs, aux_logits = outputs[0], outputs[1]
|
||||
loss = loss_fn(outputs, targets[st:en])
|
||||
|
|
|
@ -21,7 +21,7 @@ from . import measure
|
|||
from ..p_utils import get_layer_metric_array
|
||||
|
||||
from archai.nas.model import Model
|
||||
|
||||
from archai.algos.natsbench.lib.models.shape_infers.InferTinyCellNet import DynamicShapeTinyNet
|
||||
|
||||
@measure('synflow', bn=False, mode='param')
|
||||
@measure('synflow_bn', bn=True, mode='param')
|
||||
|
@ -55,11 +55,9 @@ def compute_synflow_per_weight(net, inputs, targets, mode, split_data=1, loss_fn
|
|||
inputs = torch.ones([1] + input_dim).double().to(device)
|
||||
output = net.forward(inputs)
|
||||
|
||||
# TODO: We have to deal with different output styles of
|
||||
# different APIs properly
|
||||
# # natsbench sss produces (activation, logits) tuple
|
||||
# if isinstance(outputs, Tuple) and len(outputs) == 2:
|
||||
# outputs = outputs[1]
|
||||
# natsbench sss produces (activation, logits) tuple
|
||||
if isinstance(output, DynamicShapeTinyNet) and len(output) == 2:
|
||||
output = output[1]
|
||||
if isinstance(net, Model):
|
||||
output, aux_logits = output[0], output[1]
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче