Local search refactor works nominally on local box. Needs to be tested that it reproduces results.

This commit is contained in:
Debadeepta Dey 2021-10-28 20:01:09 -07:00 коммит произвёл Gustavo Rosa
Родитель a5c8e495e6
Коммит 1505de5b14
2 изменённых файлов: 24 добавлений и 14 удалений

Просмотреть файл

@ -3,27 +3,25 @@ import torch.nn as nn
class ArchWithMetaData:
def __init__(self, arch:nn.Module, metadata:Dict):
assert isinstance(arch, nn.Module)
assert isinstance(metadata, dict)
self.arch_model = arch
self.metadata = metadata
def __init__(self, model:nn.Module, extradata:Dict):
self._arch = model
self._metadata = extradata
@property
def arch(self):
return self.arch_model
return self._arch
@arch.setter
def arch(self, arch:nn.Module):
assert isinstance(arch, nn.Module)
self.arch_model = arch
def arch(self, model:nn.Module):
assert isinstance(model, nn.Module)
self._arch = model
@property
def metadata(self):
return self.metadata
return self._metadata
@metadata.setter
def metadata(self, metadata):
assert isinstance(metadata, dict)
self.metadata = metadata
def metadata(self, extradata:Dict):
assert isinstance(extradata, dict)
self._metadata = extradata

Просмотреть файл

@ -69,4 +69,16 @@ class DiscreteSearchSpaceNatsbenchTSS(DiscreteSearchSpace):
# given a string, get the list of operations
tokens = string.split('|')
ops = [t.split('~')[0] for i,t in enumerate(tokens) if i not in [0,2,5,9]]
return ops
return ops
def _get_string_from_ops(self, ops):
''' Reused from https://github.com/naszilla/naszilla/blob/master/naszilla/nas_bench_201/cell_201.py '''
# given a list of operations, get the string
strings = ['|']
nodes = [0, 0, 1, 0, 1, 2]
for i, op in enumerate(ops):
strings.append(op+'~{}|'.format(nodes[i]))
if i < len(nodes) - 1 and nodes[i+1] == 0:
strings.append('+|')
return ''.join(strings)