mirror of https://github.com/microsoft/archai.git

remove get_op_desc, refactor random finalizers

This commit is contained in:
Parent: 8b388d905c
Commit: cf0ce350fe
@@ -1,4 +1,4 @@
-from typing import Iterable, Optional, Tuple, List
+from typing import Iterable, Optional, Tuple, List, Iterator
 
 import torch
 from torch import nn
@@ -72,11 +72,9 @@ class MixedOp(Op):
     def can_drop_path(self) -> bool:
         return False
 
-    def get_op_desc(self, index:int)->OpDesc:
-        ''' index: index in the primitives list '''
-        assert index < len(self.PRIMITIVES)
-        desc, _ = self._ops[index].finalize()
-        return desc
+    @overrides
+    def ops(self)->Iterator['Op']: # type: ignore
+        return iter(self._ops)
 
     def _set_alphas(self, alphas: Iterable[nn.Parameter]) -> None:
         # must call before adding other ops
@@ -1,4 +1,4 @@
-from typing import Iterable, Optional, Tuple, List
+from typing import Iterable, Optional, Tuple, List, Iterator
 from collections import deque
 
 import torch
@@ -34,7 +34,7 @@ class DivOp(Op):
         'none' # this must be at the end so top1 doesn't choose it
     ]
 
-    # list of primitive ops not allowed in the
+    # list of primitive ops not allowed in the
     # diversity calculation
     # NOTALLOWED = ['skip_connect', 'none']
     NOTALLOWED = ['none']
@@ -103,10 +103,10 @@ class DivOp(Op):
     @property
     def num_valid_div_ops(self)->int:
         return len(self.PRIMITIVES) - len(self.NOTALLOWED)
-
+
     @overrides
     def forward(self, x):
-
+
         # save activations to object
         if self._collect_activations:
             self._forward_counter += 1
@@ -121,7 +121,7 @@ class DivOp(Op):
             result = sum(w * op(x) for w, op in zip(asm, self._ops))
         else:
             result = sum(op(x) for op in self._ops)
-
+
         return result
 
     @overrides
@@ -136,18 +136,14 @@ class DivOp(Op):
         for w in op.parameters():
             yield w
 
-
-    def get_op_desc(self, index:int)->OpDesc:
-        ''' index: index in the primitives list '''
-        assert index < len(self.PRIMITIVES)
-        desc, _ = self._ops[index].finalize()
-        return desc
-
+    @overrides
+    def ops(self)->Iterator['Op']: # type: ignore
+        return iter(self._ops)
+
     def get_valid_op_desc(self, index:int)->OpDesc:
         ''' index: index in the valid index list '''
         assert index <= self.num_valid_div_ops
-        orig_index = self._valid_to_orig[index]
+        orig_index = self._valid_to_orig[index]
         desc, _ = self._ops[orig_index].finalize()
         return desc
-
 
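Note on the DivOp hunks above: `get_op_desc` is dropped in favor of the `ops()` iterator, while `get_valid_op_desc` survives because diversity calculations index only the allowed primitives. Below is a minimal sketch of how a `_valid_to_orig` table like the one used above can be built; the construction itself is an assumption, only the names, the `NOTALLOWED` exclusion, and the `num_valid_div_ops` arithmetic come from the diff. (The kept `assert index <= self.num_valid_div_ops` also looks off by one: `index == num_valid_div_ops` would run past the end of `_valid_to_orig`, so `<` appears intended.)

# Hypothetical reconstruction of DivOp's valid-index bookkeeping.
# Primitive names other than 'skip_connect' and 'none' are illustrative.
PRIMITIVES = ['max_pool_3x3', 'skip_connect', 'sep_conv_3x3', 'none']
NOTALLOWED = ['none']  # primitives excluded from the diversity calculation

# valid index (over allowed primitives) -> original index into PRIMITIVES
_valid_to_orig = [orig for orig, name in enumerate(PRIMITIVES)
                  if name not in NOTALLOWED]

num_valid_div_ops = len(PRIMITIVES) - len(NOTALLOWED)
assert len(_valid_to_orig) == num_valid_div_ops

# get_valid_op_desc(index) then finalizes self._ops[_valid_to_orig[index]]
print(_valid_to_orig)  # [0, 1, 2]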
@@ -1,7 +1,4 @@
-from random import sample
-from archai.common.utils import AverageMeter
-from collections import defaultdict, deque
-from typing import Iterable, Optional, Tuple, List
+from typing import Iterable, Optional, Tuple, List, Iterator
 
 import torch
 from torch import nn
@@ -118,13 +115,9 @@ class GsOp(Op):
     def can_drop_path(self) -> bool:
         return False
 
-
-    def get_op_desc(self, index:int)->OpDesc:
-        ''' index: index in the primitives list '''
-        assert index < len(self.PRIMITIVES)
-        desc, _ = self._ops[index].finalize()
-        return desc
-
+    @overrides
+    def ops(self)->Iterator['Op']: # type: ignore
+        return iter(self._ops)
 
     def _set_alphas(self, alphas: Iterable[nn.Parameter]) -> None:
         # must call before adding other ops
@@ -1,5 +1,5 @@
 from copy import deepcopy
-from typing import Iterable, List, Optional, Sequence, Tuple, Mapping
+from typing import Iterable, List, Optional, Iterator, Tuple, Mapping
 import heapq
 import copy
 
@@ -180,11 +180,9 @@ class PetridishOp(Op):
         # rank=None to indicate no further selection needed as in darts
         return final_op_desc, None
 
-    def get_op_desc(self, index:int)->OpDesc:
-        ''' index: index in the primitives list '''
-        assert index < len(self.PRIMITIVES)
-        desc, _ = self._ops[index].finalize()
-        return desc
+    @overrides
+    def ops(self)->Iterator['Op']: # type: ignore
+        return iter(self._ops)
 
     def _set_alphas(self, alphas: Iterable[nn.Parameter], in_len:int) -> None:
         assert len(list(self.parameters()))==0 # must call before adding other ops
@@ -1,5 +1,5 @@
 from argparse import ArgumentError
-from typing import Callable, Iterable, List, Mapping, Tuple, Dict, Optional, Union
+from typing import Callable, Iterable, Iterator, List, Mapping, Tuple, Dict, Optional, Union
 from abc import ABC, abstractmethod
 import copy
 
@@ -114,6 +114,10 @@ class Op(nn.Module, ABC, EnforceOverrides):
         desc.trainables = copy.deepcopy(self.get_trainables())
         return desc, None # desc, rank (None means op is unranked and cannot be removed)
 
+    def ops(self)->Iterator['Op']:
+        """Return constituent ops, if this op is primitive just return self"""
+        yield self
+
     # if op should not be dropped during drop path then return False
     def can_drop_path(self)->bool:
         return True
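The base-class hunk above is what lets every per-algorithm `get_op_desc` disappear: `Op.ops()` yields `self` for primitive ops, and composite ops (MixedOp, DivOp, GsOp, PetridishOp) override it to yield their children. A minimal self-contained sketch of the pattern; the class names here are toy stand-ins, not archai's actual classes:

from typing import Iterator, List

class Op:
    def ops(self) -> Iterator['Op']:
        """By default an op is primitive: iterating it yields only itself."""
        yield self

class ToyMixedOp(Op):
    def __init__(self, children: List[Op]) -> None:
        self._ops = children

    def ops(self) -> Iterator[Op]:
        # composite op: expose constituent ops instead of itself
        return iter(self._ops)

# Callers no longer need index-based get_op_desc(); a single comprehension
# flattens any mix of primitive and composite ops.
parents = [Op(), ToyMixedOp([Op(), Op()])]
flat = [op for parent in parents for op in parent.ops()]
assert len(flat) == 3

Any finalizer can now treat a node's incoming edges uniformly, which is exactly what the refactored `finalize_node` below relies on.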
@@ -19,50 +19,14 @@ from archai.algos.divnas.divop import DivOp
 
 
 class RandomFinalizers(Finalizers):
-
-
     @overrides
     def finalize_node(self, node:nn.ModuleList, max_final_edges:int)->NodeDesc:
-        # node is a list of edges
-        assert len(node) >= max_final_edges
-
         # get total number of ops incoming to this node
-        num_ops = 0
-        for edge in node:
-            if hasattr(edge._op, 'PRIMITIVES'):
-                num_ops += len(edge._op.PRIMITIVES)
+        in_ops = [(edge,op) for edge in node for op in edge._op.ops()]
+        assert len(in_ops) >= max_final_edges
 
-        # and collect some bookkeeping indices
-        edge_num_and_op_ind = []
-        for j, edge in enumerate(node):
-            if hasattr(edge._op, 'PRIMITIVES'):
-                for k in range(len(edge._op.PRIMITIVES)):
-                    edge_num_and_op_ind.append((j, k))
-
-        assert len(edge_num_and_op_ind) == num_ops
-
-        # run random subset selection
-        rand_subset = self._random_subset(num_ops, max_final_edges)
-
-        # convert the cov indices to edge descs
-        selected_edges = []
-        for ind in rand_subset:
-            edge_ind, op_ind = edge_num_and_op_ind[ind]
-            op_desc = node[edge_ind]._op.get_op_desc(op_ind)
-            new_edge = EdgeDesc(op_desc, node[edge_ind].input_ids)
-            selected_edges.append(new_edge)
-
+        selected = random.sample(in_ops, max_final_edges)
+        # finalize selected op, select 1st value from return which is op finalized desc
+        selected_edges = [EdgeDesc(s[1].finalize()[0], s[0].input_ids) \
+                            for s in selected]
         return NodeDesc(selected_edges)
-
-
-    def _random_subset(self, num_ops:int, max_final_edges:int)->Set[int]:
-        assert num_ops > 0
-        assert max_final_edges > 0
-        assert max_final_edges <= num_ops
-
-        S = set()
-        while len(S) < max_final_edges:
-            sample = random.randint(0, num_ops)
-            S.add(sample)
-
-        return S
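The refactor above replaces roughly forty lines of index bookkeeping (`num_ops`, `edge_num_and_op_ind`, `_random_subset`) with one flattening comprehension plus `random.sample`. It also removes a latent bug: `random.randint(0, num_ops)` is inclusive at both ends, so the deleted helper could return `num_ops` itself and index past the end of `edge_num_and_op_ind`. A runnable sketch of the new selection logic; `ToyOp` and `ToyEdge` are hypothetical stand-ins, while the flattening, the assert, and the sampling mirror the diff:

import random
from dataclasses import dataclass
from typing import Iterator, List, Tuple

class ToyOp:
    def ops(self) -> Iterator['ToyOp']:
        yield self                      # primitive op: yields only itself
    def finalize(self) -> Tuple[str, None]:
        return 'op_desc', None          # stands in for (OpDesc, rank)

@dataclass
class ToyEdge:
    _op: ToyOp
    input_ids: List[int]

def finalize_node_random(node: List[ToyEdge], max_final_edges: int):
    # flatten (edge, op) pairs across all incoming edges, as in the new code
    in_ops = [(edge, op) for edge in node for op in edge._op.ops()]
    assert len(in_ops) >= max_final_edges
    # sample without replacement; no index table, no off-by-one
    selected = random.sample(in_ops, max_final_edges)
    # keep the finalized desc (first element of finalize()'s return)
    return [(op.finalize()[0], edge.input_ids) for edge, op in selected]

node = [ToyEdge(ToyOp(), [0]), ToyEdge(ToyOp(), [1]), ToyEdge(ToyOp(), [0, 1])]
print(finalize_node_random(node, max_final_edges=2))

Sampling `(edge, op)` pairs directly also makes the selection independent of whether an edge holds a composite op with `PRIMITIVES` or a plain primitive, which the old `hasattr(edge._op, 'PRIMITIVES')` checks handled only partially.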