зеркало из https://github.com/microsoft/archai.git
Local search refactor works nominally on local box. Needs to be tested that it reproduces results.
This commit is contained in:
Родитель
a5c8e495e6
Коммит
1505de5b14
|
@ -3,27 +3,25 @@ import torch.nn as nn
|
|||
|
||||
|
||||
class ArchWithMetaData:
|
||||
def __init__(self, arch:nn.Module, metadata:Dict):
|
||||
assert isinstance(arch, nn.Module)
|
||||
assert isinstance(metadata, dict)
|
||||
self.arch_model = arch
|
||||
self.metadata = metadata
|
||||
def __init__(self, model:nn.Module, extradata:Dict):
|
||||
self._arch = model
|
||||
self._metadata = extradata
|
||||
|
||||
@property
|
||||
def arch(self):
|
||||
return self.arch_model
|
||||
return self._arch
|
||||
|
||||
@arch.setter
|
||||
def arch(self, arch:nn.Module):
|
||||
assert isinstance(arch, nn.Module)
|
||||
self.arch_model = arch
|
||||
def arch(self, model:nn.Module):
|
||||
assert isinstance(model, nn.Module)
|
||||
self._arch = model
|
||||
|
||||
@property
|
||||
def metadata(self):
|
||||
return self.metadata
|
||||
return self._metadata
|
||||
|
||||
@metadata.setter
|
||||
def metadata(self, metadata):
|
||||
assert isinstance(metadata, dict)
|
||||
self.metadata = metadata
|
||||
def metadata(self, extradata:Dict):
|
||||
assert isinstance(extradata, dict)
|
||||
self._metadata = extradata
|
||||
|
|
@ -69,4 +69,16 @@ class DiscreteSearchSpaceNatsbenchTSS(DiscreteSearchSpace):
|
|||
# given a string, get the list of operations
|
||||
tokens = string.split('|')
|
||||
ops = [t.split('~')[0] for i,t in enumerate(tokens) if i not in [0,2,5,9]]
|
||||
return ops
|
||||
return ops
|
||||
|
||||
|
||||
def _get_string_from_ops(self, ops):
|
||||
''' Reused from https://github.com/naszilla/naszilla/blob/master/naszilla/nas_bench_201/cell_201.py '''
|
||||
# given a list of operations, get the string
|
||||
strings = ['|']
|
||||
nodes = [0, 0, 1, 0, 1, 2]
|
||||
for i, op in enumerate(ops):
|
||||
strings.append(op+'~{}|'.format(nodes[i]))
|
||||
if i < len(nodes) - 1 and nodes[i+1] == 0:
|
||||
strings.append('+|')
|
||||
return ''.join(strings)
|
Загрузка…
Ссылка в новой задаче