remove get_op_desc, refactor random finalizers

This commit is contained in:
Shital Shah 2020-05-17 00:59:54 -07:00
Родитель 8b388d905c
Коммит cf0ce350fe
6 изменённых файлов: 32 добавлений и 79 удалений

Просмотреть файл

@ -1,4 +1,4 @@
from typing import Iterable, Optional, Tuple, List
from typing import Iterable, Optional, Tuple, List, Iterator
import torch
from torch import nn
@ -72,11 +72,9 @@ class MixedOp(Op):
def can_drop_path(self) -> bool:
return False
def get_op_desc(self, index:int)->OpDesc:
''' index: index in the primitives list '''
assert index < len(self.PRIMITIVES)
desc, _ = self._ops[index].finalize()
return desc
@overrides
def ops(self)->Iterator['Op']: # type: ignore
return iter(self._ops)
def _set_alphas(self, alphas: Iterable[nn.Parameter]) -> None:
# must call before adding other ops

Просмотреть файл

@ -1,4 +1,4 @@
from typing import Iterable, Optional, Tuple, List
from typing import Iterable, Optional, Tuple, List, Iterator
from collections import deque
import torch
@ -34,7 +34,7 @@ class DivOp(Op):
'none' # this must be at the end so top1 doesn't choose it
]
# list of primitive ops not allowed in the
# list of primitive ops not allowed in the
# diversity calculation
# NOTALLOWED = ['skip_connect', 'none']
NOTALLOWED = ['none']
@ -103,10 +103,10 @@ class DivOp(Op):
@property
def num_valid_div_ops(self)->int:
return len(self.PRIMITIVES) - len(self.NOTALLOWED)
@overrides
def forward(self, x):
# save activations to object
if self._collect_activations:
self._forward_counter += 1
@ -121,7 +121,7 @@ class DivOp(Op):
result = sum(w * op(x) for w, op in zip(asm, self._ops))
else:
result = sum(op(x) for op in self._ops)
return result
@overrides
@ -136,18 +136,14 @@ class DivOp(Op):
for w in op.parameters():
yield w
def get_op_desc(self, index:int)->OpDesc:
''' index: index in the primitives list '''
assert index < len(self.PRIMITIVES)
desc, _ = self._ops[index].finalize()
return desc
@overrides
def ops(self)->Iterator['Op']: # type: ignore
return iter(self._ops)
def get_valid_op_desc(self, index:int)->OpDesc:
''' index: index in the valid index list '''
assert index <= self.num_valid_div_ops
orig_index = self._valid_to_orig[index]
orig_index = self._valid_to_orig[index]
desc, _ = self._ops[orig_index].finalize()
return desc

Просмотреть файл

@ -1,7 +1,4 @@
from random import sample
from archai.common.utils import AverageMeter
from collections import defaultdict, deque
from typing import Iterable, Optional, Tuple, List
from typing import Iterable, Optional, Tuple, List, Iterator
import torch
from torch import nn
@ -118,13 +115,9 @@ class GsOp(Op):
def can_drop_path(self) -> bool:
return False
def get_op_desc(self, index:int)->OpDesc:
''' index: index in the primitives list '''
assert index < len(self.PRIMITIVES)
desc, _ = self._ops[index].finalize()
return desc
@overrides
def ops(self)->Iterator['Op']: # type: ignore
return iter(self._ops)
def _set_alphas(self, alphas: Iterable[nn.Parameter]) -> None:
# must call before adding other ops

Просмотреть файл

@ -1,5 +1,5 @@
from copy import deepcopy
from typing import Iterable, List, Optional, Sequence, Tuple, Mapping
from typing import Iterable, List, Optional, Iterator, Tuple, Mapping
import heapq
import copy
@ -180,11 +180,9 @@ class PetridishOp(Op):
# rank=None to indicate no further selection needed as in darts
return final_op_desc, None
def get_op_desc(self, index:int)->OpDesc:
''' index: index in the primitives list '''
assert index < len(self.PRIMITIVES)
desc, _ = self._ops[index].finalize()
return desc
@overrides
def ops(self)->Iterator['Op']: # type: ignore
return iter(self._ops)
def _set_alphas(self, alphas: Iterable[nn.Parameter], in_len:int) -> None:
assert len(list(self.parameters()))==0 # must call before adding other ops

Просмотреть файл

@ -1,5 +1,5 @@
from argparse import ArgumentError
from typing import Callable, Iterable, List, Mapping, Tuple, Dict, Optional, Union
from typing import Callable, Iterable, Iterator, List, Mapping, Tuple, Dict, Optional, Union
from abc import ABC, abstractmethod
import copy
@ -114,6 +114,10 @@ class Op(nn.Module, ABC, EnforceOverrides):
desc.trainables = copy.deepcopy(self.get_trainables())
return desc, None # desc, rank (None means op is unranked and cannot be removed)
def ops(self)->Iterator['Op']:
"""Return contituent ops, if this op is primitive just return self"""
yield self
# if op should not be dropped during drop path then return False
def can_drop_path(self)->bool:
return True

Просмотреть файл

@ -19,50 +19,14 @@ from archai.algos.divnas.divop import DivOp
class RandomFinalizers(Finalizers):
@overrides
def finalize_node(self, node:nn.ModuleList, max_final_edges:int)->NodeDesc:
# node is a list of edges
assert len(node) >= max_final_edges
# get total number of ops incoming to this node
num_ops = 0
for edge in node:
if hasattr(edge._op, 'PRIMITIVES'):
num_ops += len(edge._op.PRIMITIVES)
in_ops = [(edge,op) for edge in node for op in edge._op.ops()]
assert len(in_ops) >= max_final_edges
# and collect some bookkeeping indices
edge_num_and_op_ind = []
for j, edge in enumerate(node):
if hasattr(edge._op, 'PRIMITIVES'):
for k in range(len(edge._op.PRIMITIVES)):
edge_num_and_op_ind.append((j, k))
assert len(edge_num_and_op_ind) == num_ops
# run random subset selection
rand_subset = self._random_subset(num_ops, max_final_edges)
# convert the cov indices to edge descs
selected_edges = []
for ind in rand_subset:
edge_ind, op_ind = edge_num_and_op_ind[ind]
op_desc = node[edge_ind]._op.get_op_desc(op_ind)
new_edge = EdgeDesc(op_desc, node[edge_ind].input_ids)
selected_edges.append(new_edge)
selected = random.sample(in_ops, max_final_edges)
# finalize selected op, select 1st value from return which is op finalized desc
selected_edges = [EdgeDesc(s[1].finalize()[0], s[0].input_ids) \
for s in selected]
return NodeDesc(selected_edges)
def _random_subset(self, num_ops:int, max_final_edges:int)->Set[int]:
assert num_ops > 0
assert max_final_edges > 0
assert max_final_edges <= num_ops
S = set()
while len(S) < max_final_edges:
sample = random.randint(0, num_ops)
S.add(sample)
return S