[Retiarii] Coding style improvements for pylint and flake8 (#3190)

Yuge Zhang 2020-12-14 18:37:39 +08:00 committed by GitHub
Parent 593a275c92
Commit 59cd398299
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
34 changed files with 221 additions and 199 deletions

View file

@ -1,29 +1,28 @@
import logging
from typing import *
from typing import List
from ..graph import IllegalGraphError, Edge, Graph, Node, Model
from ..operation import Operation, Cell
_logger = logging.getLogger(__name__)
def model_to_pytorch_script(model: Model, placement = None) -> str:
def model_to_pytorch_script(model: Model, placement=None) -> str:
graphs = []
total_pkgs = set()
for name, cell in model.graphs.items():
import_pkgs, graph_code = graph_to_pytorch_model(name, cell, placement = placement)
import_pkgs, graph_code = graph_to_pytorch_model(name, cell, placement=placement)
graphs.append(graph_code)
total_pkgs.update(import_pkgs)
pkgs_code = '\n'.join(['import {}'.format(pkg) for pkg in total_pkgs])
return _PyTorchScriptTemplate.format(pkgs_code, '\n\n'.join(graphs)).strip()
def _sorted_incoming_edges(node: Node) -> List[Edge]:
edges = [edge for edge in node.graph.edges if edge.tail is node]
_logger.info('sorted_incoming_edges: {}'.format(edges))
_logger.info('sorted_incoming_edges: %s', str(edges))
if not edges:
return []
_logger.info(f'all tail_slots are None: {[edge.tail_slot for edge in edges]}')
_logger.info('all tail_slots are None: %s', str([edge.tail_slot for edge in edges]))
if all(edge.tail_slot is None for edge in edges):
return edges
if all(isinstance(edge.tail_slot, int) for edge in edges):
@ -32,6 +31,7 @@ def _sorted_incoming_edges(node: Node) -> List[Edge]:
return edges
raise IllegalGraphError(node.graph, 'Node {} has bad inputs'.format(node.name))
def _format_inputs(node: Node) -> List[str]:
edges = _sorted_incoming_edges(node)
inputs = []
@ -53,6 +53,7 @@ def _format_inputs(node: Node) -> List[str]:
inputs.append('{}[{}]'.format(edge.head.name, edge.head_slot))
return inputs
def _remove_prefix(names, graph_name):
"""
variables name (full name space) is too long,
@ -69,14 +70,14 @@ def _remove_prefix(names, graph_name):
else:
return names[len(graph_name):] if names.startswith(graph_name) else names
def graph_to_pytorch_model(graph_name: str, graph: Graph, placement = None) -> str:
def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str:
nodes = graph.topo_sort()
# handle module node and function node differently
# only need to generate code for module here
import_pkgs = set()
node_codes = []
placement_codes = []
for node in nodes:
if node.operation:
pkg_name = node.operation.get_import_pkg()
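The change pattern in this file recurs throughout the commit: log calls move from eager .format() / f-string interpolation to lazy %-style arguments, which pylint flags as W1202/W1203 (logging-format-interpolation, logging-fstring-interpolation). A minimal sketch of the before and after, with a hypothetical payload:

import logging

_logger = logging.getLogger(__name__)
edges = [(0, 1), (1, 2)]  # hypothetical payload

# Before: the message is formatted even when INFO logging is disabled.
_logger.info('sorted_incoming_edges: {}'.format(edges))

# After: formatting is deferred until a handler actually emits the record.
_logger.info('sorted_incoming_edges: %s', edges)

The diff additionally wraps arguments in str() (e.g. str(edges)); that is redundant with %s, which stringifies on its own, but harmless.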

View file

@ -1,2 +1 @@
from .graph_gen import convert_to_graph
from .visualize import visualize_model

View file

@ -1,14 +1,13 @@
import json_tricks
import logging
import re
import torch
from ..graph import Graph, Node, Edge, Model
from ..operation import Cell, Operation
from ..nn.pytorch import Placeholder, LayerChoice, InputChoice
from .op_types import MODULE_EXCEPT_LIST, OpTypeName, BasicOpsPT
from .utils import build_full_name, _convert_name
from ..graph import Graph, Model, Node
from ..nn.pytorch import InputChoice, LayerChoice, Placeholder
from ..operation import Cell
from .op_types import MODULE_EXCEPT_LIST, BasicOpsPT, OpTypeName
from .utils import _convert_name, build_full_name
_logger = logging.getLogger(__name__)
@ -16,6 +15,7 @@ global_seq = 0
global_graph_id = 0
modules_arg = None
def _add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap, ignore_first=False):
"""
Parameters
@ -76,6 +76,7 @@ def _add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap,
new_node_input_idx += 1
def create_prim_constant_node(ir_graph, node, module_name):
global global_seq
attrs = {}
@ -86,14 +87,17 @@ def create_prim_constant_node(ir_graph, node, module_name):
node.kind(), attrs)
return new_node
def handle_prim_attr_node(node):
assert node.hasAttribute('name')
attrs = {'name': node.s('name'), 'input': node.inputsAt(0).debugName()}
return node.kind(), attrs
def _remove_mangle(module_type_str):
return re.sub('\\.___torch_mangle_\\d+', '', module_type_str)
def remove_unconnected_nodes(ir_graph, targeted_type=None):
"""
Parameters
@ -122,6 +126,7 @@ def remove_unconnected_nodes(ir_graph, targeted_type=None):
for hidden_node in to_removes:
hidden_node.remove()
def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, ir_graph):
"""
Convert torch script node to our node ir, and build our graph ir
@ -156,7 +161,7 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i
# TODO: add scope name
ir_graph._add_input(_convert_name(_input.debugName()))
node_index = {} # graph node to graph ir node
node_index = {} # graph node to graph ir node
# some node does not have output but it modifies a variable, for example aten::append
# %17 : Tensor[] = aten::append(%out.1, %16)
@ -248,13 +253,14 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i
# therefore, we do this check for a module. example below:
# %25 : __torch__.xxx = prim::GetAttr[name="input_switch"](%self)
# %27 : Tensor = prim::CallMethod[name="forward"](%25, %out.1)
assert submodule_name in script_module._modules, "submodule_name: {} not in script_module {}".format(submodule_name, script_module._modules.keys())
assert submodule_name in script_module._modules, "submodule_name: {} not in script_module {}".format(
submodule_name, script_module._modules.keys())
submodule_full_name = build_full_name(module_name, submodule_name)
submodule_obj = getattr(module, submodule_name)
subgraph, sub_m_attrs = convert_module(script_module._modules[submodule_name],
submodule_obj,
submodule_full_name, ir_model)
submodule_obj,
submodule_full_name, ir_model)
else:
# %8 : __torch__.nni.retiarii.model_apis.nn.___torch_mangle_37.ModuleList = prim::GetAttr[name="cells"](%self)
# %10 : __torch__.darts_model.Cell = prim::GetAttr[name="0"](%8)
@ -271,7 +277,7 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i
predecessor_obj = getattr(module, predecessor_name)
submodule_obj = getattr(predecessor_obj, submodule_name)
subgraph, sub_m_attrs = convert_module(script_module._modules[predecessor_name]._modules[submodule_name],
submodule_obj, submodule_full_name, ir_model)
submodule_obj, submodule_full_name, ir_model)
else:
raise RuntimeError('Unsupported module case: {}'.format(submodule.inputsAt(0).type().str()))
@ -329,7 +335,7 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i
node_type, attrs = handle_prim_attr_node(node)
global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Attr, global_seq),
node_type, attrs)
node_type, attrs)
node_index[node] = new_node
elif node.kind() == 'prim::min':
print('zql: ', sm_graph)
@ -350,6 +356,7 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i
return node_index
def merge_aten_slices(ir_graph):
"""
if there is aten::slice node, merge the consecutive ones together.
@ -367,7 +374,7 @@ def merge_aten_slices(ir_graph):
break
if has_slice_node:
assert head_slice_nodes
for head_node in head_slice_nodes:
slot = 0
new_slice_node = ir_graph.add_node(build_full_name(head_node.name, 'merged'), OpTypeName.MergedSlice)
@ -391,11 +398,11 @@ def merge_aten_slices(ir_graph):
slot += 4
ir_graph.hidden_nodes.remove(node)
node = suc_node
for edge in node.outgoing_edges:
edge.head = new_slice_node
ir_graph.hidden_nodes.remove(node)
def refine_graph(ir_graph):
"""
@ -408,13 +415,14 @@ def refine_graph(ir_graph):
remove_unconnected_nodes(ir_graph, targeted_type='prim::GetAttr')
merge_aten_slices(ir_graph)
def _handle_layerchoice(module):
global modules_arg
m_attrs = {}
candidates = module.candidate_ops
choices = []
for i, cand in enumerate(candidates):
for cand in candidates:
assert id(cand) in modules_arg, 'id not exist: {}'.format(id(cand))
assert isinstance(modules_arg[id(cand)], dict)
cand_type = '__torch__.' + cand.__class__.__module__ + '.' + cand.__class__.__name__
@ -423,6 +431,7 @@ def _handle_layerchoice(module):
m_attrs['label'] = module.label
return m_attrs
def _handle_inputchoice(module):
m_attrs = {}
m_attrs['n_chosen'] = module.n_chosen
@ -430,6 +439,7 @@ def _handle_inputchoice(module):
m_attrs['label'] = module.label
return m_attrs
def convert_module(script_module, module, module_name, ir_model):
"""
Convert a module to its graph ir (i.e., Graph) along with its input arguments
@ -503,10 +513,11 @@ def convert_module(script_module, module, module_name, ir_model):
# TODO: if we parse this module, it means we will create a graph (module class)
# for this module. Then it is not necessary to record this module's arguments
# return ir_graph, modules_arg[id(module)].
# That is, we can refactor this part, to allow users to annotate which module
# That is, we can refactor this part, to allow users to annotate which module
# should not be parsed further.
return ir_graph, {}
def convert_to_graph(script_module, module, recorded_modules_arg):
"""
Convert module to our graph ir, i.e., build a ```Model``` type
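Most of the remaining hunks in this file add nothing but a blank line before a top-level def or class: flake8's E302 requires two blank lines between top-level definitions. A trivial sketch of the rule:

GREETING = 'hi'


def first():  # two blank lines above satisfy E302
    return GREETING


def second():
    return GREETING.upper()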

View file

@ -16,6 +16,7 @@ class OpTypeName(str, Enum):
Placeholder = 'Placeholder'
MergedSlice = 'MergedSlice'
# deal with aten op
BasicOpsPT = {
'aten::mean': 'Mean',
@ -29,7 +30,7 @@ BasicOpsPT = {
'aten::size': 'Size',
'aten::view': 'View',
'aten::eq': 'Eq',
'aten::add_': 'Add_' # %out.3 : Tensor = aten::add_(%out.1, %connection.1, %4)
'aten::add_': 'Add_' # %out.3 : Tensor = aten::add_(%out.1, %connection.1, %4)
}
BasicOpsTF = {}
BasicOpsTF = {}
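The repeated 'BasicOpsTF = {}' just above is how this rendering shows a byte-level change: the old file ended without a trailing newline (flake8 W292) and the new one adds it. The same repeated-last-line pattern closes several files below (vgraph.render(), return resp['status'], and others). An illustrative helper, not part of the diff, to spot such files:

def missing_final_newline(path: str) -> bool:
    # W292 fires when the last line has no terminating '\n'.
    with open(path, 'rb') as f:
        data = f.read()
    return bool(data) and not data.endswith(b'\n')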

View file

@ -6,6 +6,7 @@ def build_full_name(prefix, name, seq=None):
else:
return '{}__{}{}'.format(prefix, name, str(seq))
def _convert_name(name: str) -> str:
"""
Convert the names using separator '.' to valid variable name in code

View file

@ -1,5 +1,6 @@
import graphviz
def convert_to_visualize(graph_ir, vgraph):
for name, graph in graph_ir.items():
if name == '_training_config':
@ -33,7 +34,8 @@ def convert_to_visualize(graph_ir, vgraph):
dst = cell_node[dst][0]
subgraph.edge(src, dst)
def visualize_model(graph_ir):
vgraph = graphviz.Digraph('G', filename='vgraph', format='jpg')
convert_to_visualize(graph_ir, vgraph)
vgraph.render()
vgraph.render()

View file

@ -1,12 +1,11 @@
import time
import os
import importlib.util
from typing import *
from typing import List
from ..graph import Model, ModelStatus
from .base import BaseExecutionEngine
from .cgo_engine import CGOExecutionEngine
from .interface import *
from .interface import AbstractExecutionEngine, WorkerInfo
from .listener import DefaultListener
_execution_engine = None
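This file and most others swap wildcard imports (from typing import *, from .interface import *) for explicit names. Star imports trigger pylint W0401 and flake8 F403/F405 because the linter can no longer tell where a name is defined. Sketch:

# Before: every public name in typing lands in this namespace.
# from typing import *

# After: import exactly what the module uses.
from typing import List, Optional


def first_or_none(items: Optional[List[int]]) -> Optional[int]:
    return items[0] if items else None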

View file

@ -1,5 +1,5 @@
import logging
from typing import *
from typing import Dict, Any, List
from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo
from .. import codegen, utils
@ -61,16 +61,16 @@ class BaseExecutionEngine(AbstractExecutionEngine):
def _send_trial_callback(self, paramater: dict) -> None:
for listener in self._listeners:
_logger.warning('resources: {}'.format(listener.resources))
_logger.warning('resources: %s', listener.resources)
if not listener.has_available_resource():
_logger.warning('There is no available resource, but trial is submitted.')
listener.on_resource_used(1)
_logger.warning('on_resource_used: {}'.format(listener.resources))
_logger.warning('on_resource_used: %s', listener.resources)
def _request_trial_jobs_callback(self, num_trials: int) -> None:
for listener in self._listeners:
listener.on_resource_available(1 * num_trials)
_logger.warning('on_resource_available: {}'.format(listener.resources))
_logger.warning('on_resource_available: %s', listener.resources)
def _trial_end_callback(self, trial_id: int, success: bool) -> None:
model = self._running_models[trial_id]

View file

@ -1,6 +1,5 @@
import logging
import json
from typing import *
from typing import List, Dict, Tuple
from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo
from .. import codegen, utils
@ -12,8 +11,10 @@ from .logical_optimizer.opt_dedup_input import DedupInputOptimizer
from .base import BaseGraphData
_logger = logging.getLogger(__name__)
class CGOExecutionEngine(AbstractExecutionEngine):
def __init__(self, n_model_per_graph = 4) -> None:
def __init__(self, n_model_per_graph=4) -> None:
self._listeners: List[AbstractGraphListener] = []
self._running_models: Dict[int, Model] = dict()
self.logical_plan_counter = 0
@ -30,38 +31,37 @@ class CGOExecutionEngine(AbstractExecutionEngine):
advisor.intermediate_metric_callback = self._intermediate_metric_callback
advisor.final_metric_callback = self._final_metric_callback
def add_optimizer(self, opt):
self._optimizers.append(opt)
def submit_models(self, *models: List[Model]) -> None:
_logger.info(f'{len(models)} Models are submitted')
_logger.info('%d models are submitted', len(models))
logical = self._build_logical(models)
for opt in self._optimizers:
opt.convert(logical)
phy_models_and_placements = self._assemble(logical)
for model, placement, grouped_models in phy_models_and_placements:
data = BaseGraphData(codegen.model_to_pytorch_script(model, placement=placement),
model.training_config.module, model.training_config.kwargs)
model.training_config.module, model.training_config.kwargs)
for m in grouped_models:
self._original_models[m.model_id] = m
self._original_model_to_multi_model[m.model_id] = model
self._running_models[send_trial(data.dump())] = model
# for model in models:
# data = BaseGraphData(codegen.model_to_pytorch_script(model),
# model.config['trainer_module'], model.config['trainer_kwargs'])
# self._running_models[send_trial(data.dump())] = model
def _assemble(self, logical_plan : LogicalPlan) -> List[Tuple[Model, PhysicalDevice]]:
def _assemble(self, logical_plan: LogicalPlan) -> List[Tuple[Model, PhysicalDevice]]:
# unique_models = set()
# for node in logical_plan.graph.nodes:
# if node.graph.model not in unique_models:
# unique_models.add(node.graph.model)
# return [m for m in unique_models]
grouped_models : List[Dict[Model, PhysicalDevice]] = AssemblePolicy().group(logical_plan)
grouped_models: List[Dict[Model, PhysicalDevice]] = AssemblePolicy().group(logical_plan)
phy_models_and_placements = []
for multi_model in grouped_models:
model, model_placement = logical_plan.assemble(multi_model)
@ -69,7 +69,7 @@ class CGOExecutionEngine(AbstractExecutionEngine):
return phy_models_and_placements
def _build_logical(self, models: List[Model]) -> LogicalPlan:
logical_plan = LogicalPlan(id = self.logical_plan_counter)
logical_plan = LogicalPlan(plan_id=self.logical_plan_counter)
for model in models:
logical_plan.add_model(model)
self.logical_plan_counter += 1
@ -108,7 +108,7 @@ class CGOExecutionEngine(AbstractExecutionEngine):
for model_id in merged_metrics:
int_model_id = int(model_id)
self._original_models[int_model_id].intermediate_metrics.append(merged_metrics[model_id])
#model.intermediate_metrics.append(metrics)
# model.intermediate_metrics.append(metrics)
for listener in self._listeners:
listener.on_intermediate_metric(self._original_models[int_model_id], merged_metrics[model_id])
@ -117,10 +117,9 @@ class CGOExecutionEngine(AbstractExecutionEngine):
for model_id in merged_metrics:
int_model_id = int(model_id)
self._original_models[int_model_id].intermediate_metrics.append(merged_metrics[model_id])
#model.intermediate_metrics.append(metrics)
# model.intermediate_metrics.append(metrics)
for listener in self._listeners:
listener.on_metric(self._original_models[int_model_id], merged_metrics[model_id])
def query_available_resource(self) -> List[WorkerInfo]:
raise NotImplementedError # move the method from listener to here?
@ -141,6 +140,7 @@ class CGOExecutionEngine(AbstractExecutionEngine):
trainer_instance = trainer_cls(model_cls(), graph_data.training_kwargs)
trainer_instance.fit()
class AssemblePolicy:
@staticmethod
def group(logical_plan):
@ -148,4 +148,3 @@ class AssemblePolicy:
for idx, m in enumerate(logical_plan.models):
group_model[m] = PhysicalDevice('server', f'cuda:{idx}')
return [group_model]
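Two spacing rules recur in this class: flake8 E251, no spaces around '=' in a keyword default (n_model_per_graph = 4 becomes n_model_per_graph=4), and E203, no space before the annotation colon (grouped_models : List[...] becomes grouped_models: List[...]). Sketch:

from typing import Dict, List


def assemble(n_model_per_graph=4) -> List[str]:  # E251-clean default
    grouped: Dict[str, int] = {'server': n_model_per_graph}  # E203-clean annotation
    return list(grouped)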

View file

@ -1,5 +1,5 @@
from abc import *
from typing import *
from abc import ABC, abstractmethod, abstractclassmethod
from typing import Any, NewType, List
from ..graph import Model, MetricData

View file

@ -1,7 +1,5 @@
from typing import *
from ..graph import *
from .interface import *
from ..graph import Model, ModelStatus
from .interface import MetricData, AbstractGraphListener
class DefaultListener(AbstractGraphListener):

View file

@ -1,8 +1,8 @@
from abc import *
from typing import *
from abc import ABC
from .logical_plan import LogicalPlan
class AbstractOptimizer(ABC):
def __init__(self) -> None:
pass

View file

@ -1,9 +1,8 @@
from nni.retiarii.operation import Operation
from nni.retiarii.graph import Model, Graph, Edge, Node, Cell
from typing import *
import logging
from nni.retiarii.operation import _IOPseudoOperation
import copy
from typing import Dict, Tuple, List, Any
from ...graph import Cell, Edge, Graph, Model, Node
from ...operation import Operation, _IOPseudoOperation
class PhysicalDevice:
@ -108,11 +107,11 @@ class OriginNode(AbstractLogicalNode):
class LogicalPlan:
def __init__(self, id=0) -> None:
def __init__(self, plan_id=0) -> None:
self.lp_model = Model(_internal=True)
self.id = id
self.id = plan_id
self.logical_graph = LogicalGraph(
self.lp_model, id, name=f'{id}', _internal=True)._register()
self.lp_model, self.id, name=f'{self.id}', _internal=True)._register()
self.lp_model._root_graph_name = self.logical_graph.name
self.models = []
@ -148,7 +147,7 @@ class LogicalPlan:
phy_model.training_config.kwargs['is_multi_model'] = True
phy_model.training_config.kwargs['model_cls'] = phy_graph.name
phy_model.training_config.kwargs['model_kwargs'] = []
#FIXME: allow user to specify
# FIXME: allow user to specify
phy_model.training_config.module = 'nni.retiarii.trainer.PyTorchMultiModelTrainer'
# merge sub-graphs
@ -158,10 +157,9 @@ class LogicalPlan:
model.graphs[graph_name]._fork_to(
phy_model, name_prefix=f'M_{model.model_id}_')
# When replace logical nodes, merge the training configs when
# input/output nodes are replaced.
training_config_slot = {} # Model ID -> Slot ID
training_config_slot = {} # Model ID -> Slot ID
input_slot_mapping = {}
output_slot_mapping = {}
# Replace all logical nodes to executable physical nodes
@ -230,7 +228,7 @@ class LogicalPlan:
to_node = copied_op[(edge.head, tail_placement)]
else:
to_operation = Operation.new(
'ToDevice', {"device":tail_placement.device})
'ToDevice', {"device": tail_placement.device})
to_node = Node(phy_graph, phy_model._uid(),
edge.head.name+"_to_"+edge.tail.name, to_operation)._register()
Edge((edge.head, edge.head_slot),
@ -249,19 +247,18 @@ class LogicalPlan:
if edge.head in input_nodes:
edge.head_slot = input_slot_mapping[edge.head]
edge.head = phy_graph.input_node
# merge all output nodes into one with multiple slots
output_nodes = []
for node in phy_graph.hidden_nodes:
if isinstance(node.operation, _IOPseudoOperation) and node.operation.type == '_outputs':
output_nodes.append(node)
for edge in phy_graph.edges:
if edge.tail in output_nodes:
edge.tail_slot = output_slot_mapping[edge.tail]
edge.tail = phy_graph.output_node
for node in input_nodes:
node.remove()
for node in output_nodes:
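LogicalPlan.__init__(self, id=0) shadowed the builtin id() (pylint W0622, redefined-builtin); the diff renames the parameter to plan_id while keeping the self.id attribute, so positional callers are unaffected and the builtin is again reachable inside the method. Sketch:

class LogicalPlan:
    def __init__(self, plan_id=0) -> None:
        # An attribute named id is fine; only the parameter name shadowed the builtin.
        self.id = plan_id
        print(id(self))  # builtin id() resolves normally again


plan = LogicalPlan(plan_id=1)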

View file

@ -1,10 +0,0 @@
from .base_optimizer import BaseOptimizer
from .logical_plan import LogicalPlan
class BatchingOptimizer(BaseOptimizer):
def __init__(self) -> None:
pass
def convert(self, logical_plan: LogicalPlan) -> None:
pass

View file

@ -1,32 +1,33 @@
from .interface import AbstractOptimizer
from .logical_plan import LogicalPlan, AbstractLogicalNode, LogicalGraph, OriginNode, PhysicalDevice
from nni.retiarii import Graph, Node, Model
from typing import *
from nni.retiarii.operation import _IOPseudoOperation
from typing import List, Dict, Tuple
from ...graph import Graph, Model, Node
from .interface import AbstractOptimizer
from .logical_plan import (AbstractLogicalNode, LogicalGraph, LogicalPlan,
OriginNode, PhysicalDevice)
_supported_training_modules = ['nni.retiarii.trainer.PyTorchImageClassificationTrainer']
class DedupInputNode(AbstractLogicalNode):
def __init__(self, logical_graph : LogicalGraph, id : int, \
nodes_to_dedup : List[Node], _internal=False):
super().__init__(logical_graph, id, \
"Dedup_"+nodes_to_dedup[0].name, \
nodes_to_dedup[0].operation)
self.origin_nodes : List[OriginNode] = nodes_to_dedup.copy()
def __init__(self, logical_graph: LogicalGraph, node_id: int,
nodes_to_dedup: List[Node], _internal=False):
super().__init__(logical_graph, node_id,
"Dedup_"+nodes_to_dedup[0].name,
nodes_to_dedup[0].operation)
self.origin_nodes: List[OriginNode] = nodes_to_dedup.copy()
def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) -> Tuple[Node, PhysicalDevice]:
for node in self.origin_nodes:
if node.original_graph.model in multi_model_placement:
new_node = Node(node.original_graph, node.id, \
f'M_{node.original_graph.model.model_id}_{node.name}', \
node.operation)
new_node = Node(node.original_graph, node.id,
f'M_{node.original_graph.model.model_id}_{node.name}',
node.operation)
return new_node, multi_model_placement[node.original_graph.model]
raise ValueError(f'DedupInputNode {self.name} does not contain nodes from multi_model')
def _fork_to(self, graph: Graph):
DedupInputNode(graph, self.id, self.origin_nodes)._register()
def __repr__(self) -> str:
return f'DedupNode(id={self.id}, name={self.name}, \
len(nodes_to_dedup)={len(self.origin_nodes)}'
@ -35,6 +36,7 @@ class DedupInputNode(AbstractLogicalNode):
class DedupInputOptimizer(AbstractOptimizer):
def __init__(self) -> None:
pass
def _check_deduplicate_by_node(self, root_node, node_to_check):
if root_node == node_to_check:
return True
@ -50,13 +52,12 @@ class DedupInputOptimizer(AbstractOptimizer):
return False
else:
return False
def convert(self, logical_plan: LogicalPlan) -> None:
nodes_to_skip = set()
while True: # repeat until the logical_graph converges
while True: # repeat until the logical_graph converges
input_nodes = logical_plan.logical_graph.get_nodes_by_type("_inputs")
#_PseudoOperation(type_name="_inputs"))
# _PseudoOperation(type_name="_inputs"))
root_node = None
for node in input_nodes:
if node in nodes_to_skip:
@ -64,21 +65,21 @@ class DedupInputOptimizer(AbstractOptimizer):
root_node = node
break
if root_node == None:
break # end of convert
break # end of convert
else:
nodes_to_dedup = []
for node in input_nodes:
if node in nodes_to_skip:
continue
if self._check_deduplicate_by_node(root_node, node):
nodes_to_dedup.append(node)
nodes_to_dedup.append(node)
assert(len(nodes_to_dedup) >= 1)
if len(nodes_to_dedup) == 1:
assert(nodes_to_dedup[0] == root_node)
nodes_to_skip.add(root_node)
else:
dedup_node = DedupInputNode(logical_plan.logical_graph, \
logical_plan.lp_model._uid(), nodes_to_dedup)._register()
dedup_node = DedupInputNode(logical_plan.logical_graph,
logical_plan.lp_model._uid(), nodes_to_dedup)._register()
for edge in logical_plan.logical_graph.edges:
if edge.head in nodes_to_dedup:
edge.head = dedup_node
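DedupInputNode's constructor wrapped its arguments with backslash continuations; PEP 8 (and pylint) prefer implicit continuation once a bracket is open, with wrapped arguments aligned under it (flake8 E128 otherwise). A self-contained sketch with stand-in values:

def make_node(graph, node_id, name, operation):
    return (graph, node_id, name, operation)


graph, node_id, name, operation = 'g', 0, 'dedup', 'Conv2d'  # hypothetical stand-ins

# Backslashes are redundant inside an open call; alignment does the work.
new_node = make_node(graph, node_id,
                     name,
                     operation)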

View file

@ -1,10 +0,0 @@
from .base_optimizer import BaseOptimizer
from .logical_plan import LogicalPlan
class WeightSharingOptimizer(BaseOptimizer):
def __init__(self) -> None:
pass
def convert(self, logical_plan: LogicalPlan) -> None:
pass

View file

@ -1,27 +1,31 @@
import dataclasses
import logging
import time
from dataclasses import dataclass
from pathlib import Path
from subprocess import Popen
from threading import Thread
from typing import Any, List, Optional
from typing import Any, Optional
from ..experiment import Experiment, TrainingServiceConfig
from ..experiment import launcher, rest
from ..experiment import Experiment, TrainingServiceConfig, launcher, rest
from ..experiment.config.base import ConfigBase, PathLike
from ..experiment.config import util
from ..experiment.pipe import Pipe
from .graph import Model
from .utils import get_records
from .integration import RetiariiAdvisor
from .converter.graph_gen import convert_to_graph
from .mutator import LayerChoiceMutator, InputChoiceMutator
from .converter import convert_to_graph
from .mutator import Mutator, LayerChoiceMutator, InputChoiceMutator
from .trainer.interface import BaseTrainer
from .strategies.strategy import BaseStrategy
_logger = logging.getLogger(__name__)
@dataclass(init=False)
class RetiariiExeConfig(ConfigBase):
experiment_name: Optional[str] = None
search_space: Any = '' # TODO: remove
search_space: Any = '' # TODO: remove
trial_command: str = 'python3 -m nni.retiarii.trial_entry'
trial_code_directory: PathLike = '.'
trial_concurrency: int
@ -52,6 +56,7 @@ class RetiariiExeConfig(ConfigBase):
def _validation_rules(self):
return _validation_rules
_canonical_rules = {
'trial_code_directory': util.canonical_path,
'max_experiment_duration': lambda value: f'{util.parse_time(value)}s' if value is not None else None,
@ -70,8 +75,8 @@ _validation_rules = {
class RetiariiExperiment(Experiment):
def __init__(self, base_model: 'nn.Module', trainer: 'BaseTrainer',
applied_mutators: List['Mutator'], strategy: 'BaseStrategy'):
def __init__(self, base_model: Model, trainer: BaseTrainer,
applied_mutators: Mutator, strategy: BaseStrategy):
self.config: RetiariiExeConfig = None
self.port: Optional[int] = None
@ -139,7 +144,7 @@ class RetiariiExperiment(Experiment):
debug
Whether to start in debug mode.
"""
# FIXME:
# FIXME:
if debug:
logging.getLogger('nni').setLevel(logging.DEBUG)
@ -189,4 +194,4 @@ class RetiariiExperiment(Experiment):
if self.port is None:
raise RuntimeError('Experiment is not running')
resp = rest.get(self.port, '/check-status')
return resp['status']
return resp['status']

View file

@ -5,7 +5,6 @@ Model representation.
import copy
from enum import Enum
import json
from collections import defaultdict
from typing import (Any, Dict, List, Optional, Tuple, Union, overload)
from .operation import Cell, Operation, _IOPseudoOperation
@ -329,12 +328,12 @@ class Graph:
Returns nodes whose operation is specified typed.
"""
return [node for node in self.hidden_nodes if node.operation.type == operation_type]
def get_node_by_id(self, id: int) -> Optional['Node']:
def get_node_by_id(self, node_id: int) -> Optional['Node']:
"""
Returns the node which has specified name; or returns `None` if no node has this name.
"""
found = [node for node in self.nodes if node.id == id]
found = [node for node in self.nodes if node.id == node_id]
return found[0] if found else None
def get_nodes_by_label(self, label: str) -> List['Node']:
@ -365,7 +364,8 @@ class Graph:
curr_nodes.append(successor)
for key in node_to_fanin:
assert node_to_fanin[key] == 0, '{}, fanin: {}, predecessor: {}, edges: {}, fanin: {}, keys: {}'.format(key,
assert node_to_fanin[key] == 0, '{}, fanin: {}, predecessor: {}, edges: {}, fanin: {}, keys: {}'.format(
key,
node_to_fanin[key],
key.predecessors[0],
self.edges,
@ -587,6 +587,7 @@ class Node:
ret['label'] = self.label
return ret
class Edge:
"""
A tensor, or "data flow", between two nodes.

View file

@ -1,17 +1,14 @@
import logging
import threading
from typing import *
from typing import Any, Callable
import json_tricks
import nni
from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
from nni.runtime.protocol import send, CommandType
from nni.runtime.protocol import CommandType, send
from nni.utils import MetricType
from . import utils
from .graph import MetricData
_logger = logging.getLogger('nni.msg_dispatcher_base')
@ -44,6 +41,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
final_metric_callback
"""
def __init__(self):
super(RetiariiAdvisor, self).__init__()
register_advisor(self) # register the current advisor as the "global only" advisor
@ -88,28 +86,28 @@ class RetiariiAdvisor(MsgDispatcherBase):
'parameters': parameters,
'parameter_source': 'algorithm'
}
_logger.info('New trial sent: {}'.format(new_trial))
_logger.info('New trial sent: %s', new_trial)
send(CommandType.NewTrialJob, json_tricks.dumps(new_trial))
if self.send_trial_callback is not None:
self.send_trial_callback(parameters) # pylint: disable=not-callable
return self.parameters_count
def handle_request_trial_jobs(self, num_trials):
_logger.info('Request trial jobs: {}'.format(num_trials))
_logger.info('Request trial jobs: %s', num_trials)
if self.request_trial_jobs_callback is not None:
self.request_trial_jobs_callback(num_trials) # pylint: disable=not-callable
def handle_update_search_space(self, data):
_logger.info('Received search space: {}'.format(data))
_logger.info('Received search space: %s', data)
self.search_space = data
def handle_trial_end(self, data):
_logger.info('Trial end: {}'.format(data)) # do nothing
_logger.info('Trial end: %s', data)
self.trial_end_callback(json_tricks.loads(data['hyper_params'])['parameter_id'], # pylint: disable=not-callable
data['event'] == 'SUCCEEDED')
def handle_report_metric_data(self, data):
_logger.info('Metric reported: {}'.format(data))
_logger.info('Metric reported: %s', data)
if data['type'] == MetricType.REQUEST_PARAMETER:
raise ValueError('Request parameter not supported')
elif data['type'] == MetricType.PERIODICAL:

View file

@ -13,6 +13,7 @@ class Sampler:
"""
Handles `Mutator.choice()` calls.
"""
def choice(self, candidates: List[Choice], mutator: 'Mutator', model: Model, index: int) -> Choice:
raise NotImplementedError()
@ -35,6 +36,7 @@ class Mutator:
For certain mutator subclasses, strategy or sampler can use `Mutator.dry_run()` to predict choice candidates.
# Method names are open for discussion.
"""
def __init__(self, sampler: Optional[Sampler] = None):
self.sampler: Optional[Sampler] = sampler
self._cur_model: Optional[Model] = None
@ -77,7 +79,6 @@ class Mutator:
self.sampler = sampler_backup
return recorder.recorded_candidates, new_model
def mutate(self, model: Model) -> None:
"""
Abstract method to be implemented by subclass.
@ -105,6 +106,7 @@ class _RecorderSampler(Sampler):
# the following is for inline mutation
class LayerChoiceMutator(Mutator):
def __init__(self, node_name: str, candidates: List):
super().__init__()
@ -118,6 +120,7 @@ class LayerChoiceMutator(Mutator):
chosen_cand = self.candidates[chosen_index]
target.update_operation(chosen_cand['type'], chosen_cand['parameters'])
class InputChoiceMutator(Mutator):
def __init__(self, node_name: str, n_chosen: int):
super().__init__()
@ -129,4 +132,4 @@ class InputChoiceMutator(Mutator):
candidates = [i for i in range(self.n_chosen)]
chosen = self.choice(candidates)
target.update_operation('__torch__.nni.retiarii.nn.pytorch.nn.ChosenInputs',
{'chosen': chosen})
{'chosen': chosen})

View file

@ -1,8 +1,9 @@
import inspect
import logging
from typing import Any, List
import torch
import torch.nn as nn
from typing import (Any, Tuple, List, Optional)
from ...utils import add_record
@ -10,7 +11,7 @@ _logger = logging.getLogger(__name__)
__all__ = [
'LayerChoice', 'InputChoice', 'Placeholder',
'Module', 'Sequential', 'ModuleList', # TODO: 'ModuleDict', 'ParameterList', 'ParameterDict',
'Module', 'Sequential', 'ModuleList', # TODO: 'ModuleDict', 'ParameterList', 'ParameterDict',
'Identity', 'Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d',
'ConvTranspose2d', 'ConvTranspose3d', 'Threshold', 'ReLU', 'Hardtanh', 'ReLU6',
'Sigmoid', 'Tanh', 'Softmax', 'Softmax2d', 'LogSoftmax', 'ELU', 'SELU', 'CELU', 'GLU', 'GELU', 'Hardshrink',
@ -30,7 +31,7 @@ __all__ = [
'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer',
#'LazyLinear', 'LazyConv1d', 'LazyConv2d', 'LazyConv3d',
#'LazyConvTranspose1d', 'LazyConvTranspose2d', 'LazyConvTranspose3d',
#'Unflatten', 'SiLU', 'TripletMarginWithDistanceLoss', 'ChannelShuffle',
#'Unflatten', 'SiLU', 'TripletMarginWithDistanceLoss', 'ChannelShuffle',
'Flatten', 'Hardsigmoid', 'Hardswish'
]
@ -57,9 +58,10 @@ class InputChoice(nn.Module):
if n_candidates or choose_from or return_mask:
_logger.warning('input arguments `n_candidates`, `choose_from` and `return_mask` are deprecated!')
def forward(self, candidate_inputs: List['Tensor']) -> 'Tensor':
def forward(self, candidate_inputs: List[torch.Tensor]) -> torch.Tensor:
# fake return
return torch.tensor(candidate_inputs)
return torch.tensor(candidate_inputs) # pylint: disable=not-callable
class ValueChoice:
"""
@ -67,6 +69,7 @@ class ValueChoice:
when instantiating a pytorch module.
TODO: can also be used in training approach
"""
def __init__(self, candidate_values: List[Any]):
self.candidate_values = candidate_values
@ -81,6 +84,7 @@ class Placeholder(nn.Module):
def forward(self, x):
return x
class ChosenInputs(nn.Module):
def __init__(self, chosen: int):
super().__init__()
@ -92,20 +96,24 @@ class ChosenInputs(nn.Module):
# the following are pytorch modules
class Module(nn.Module):
def __init__(self):
super(Module, self).__init__()
class Sequential(nn.Sequential):
def __init__(self, *args):
add_record(id(self), {})
super(Sequential, self).__init__(*args)
class ModuleList(nn.ModuleList):
def __init__(self, *args):
add_record(id(self), {})
super(ModuleList, self).__init__(*args)
def wrap_module(original_class):
orig_init = original_class.__init__
argname_list = list(inspect.signature(original_class).parameters.keys())
@ -115,14 +123,15 @@ def wrap_module(original_class):
full_args = {}
full_args.update(kws)
for i, arg in enumerate(args):
full_args[argname_list[i]] = args[i]
full_args[argname_list[i]] = arg
add_record(id(self), full_args)
orig_init(self, *args, **kws) # Call the original __init__
orig_init(self, *args, **kws) # Call the original __init__
original_class.__init__ = __init__ # Set the class' __init__ to the new one
original_class.__init__ = __init__ # Set the class' __init__ to the new one
return original_class
# TODO: support different versions of pytorch
Identity = wrap_module(nn.Identity)
Linear = wrap_module(nn.Linear)
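Beyond spacing, wrap_module (and the twin helpers in utils.py below) had a small redundancy: the loop unpacked each positional argument into arg but then re-indexed args[i], leaving arg unused (pylint W0612). The fix uses the loop variable directly. Self-contained sketch of the mechanism:

import inspect


def record_args(cls, args):
    # Map positional args to parameter names, as wrap_module does.
    argname_list = list(inspect.signature(cls).parameters.keys())
    full_args = {}
    for i, arg in enumerate(args):
        full_args[argname_list[i]] = arg  # was: args[i]
    return full_args


class FakeLinear:  # hypothetical stand-in for a wrapped nn module
    def __init__(self, in_features, out_features):
        pass


print(record_args(FakeLinear, (3, 4)))  # {'in_features': 3, 'out_features': 4}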

View file

@ -4,12 +4,14 @@ from . import debug_configs
__all__ = ['Operation', 'Cell']
def _convert_name(name: str) -> str:
"""
Convert the names using separator '.' to valid variable name in code
"""
return name.replace('.', '__')
class Operation:
"""
Calculation logic of a graph node.
@ -152,6 +154,7 @@ class PyTorchOperation(Operation):
else:
raise RuntimeError(f'unsupported operation type: {self.type} ? {self._to_class_name()}')
class TensorFlowOperation(Operation):
def _to_class_name(self) -> str:
return 'K.layers.' + self.type
@ -191,6 +194,7 @@ class Cell(PyTorchOperation):
framework
No real usage. Exists for compatibility with base class.
"""
def __init__(self, cell_name: str, parameters: Dict[str, Any] = {}):
self.type = '_cell'
self.cell_name = cell_name
@ -207,6 +211,7 @@ class _IOPseudoOperation(Operation):
The benefit is that users no longer need to verify `Node.operation is not None`,
especially in static type checking.
"""
def __init__(self, type_name: str, io_names: List = None):
assert type_name.startswith('_')
super(_IOPseudoOperation, self).__init__(type_name, {}, True)

View file

@ -1,7 +1,8 @@
from ..operation import TensorFlowOperation
class Conv2D(TensorFlowOperation):
def __init__(self, type_name, parameters, _internal):
if 'padding' not in parameters:
parameters['padding'] = 'same'
super().__init__(type_name, parameters, _internal)
super().__init__(type_name, parameters, _internal)

View file

@ -1,5 +1,6 @@
from ..operation import PyTorchOperation
class relu(PyTorchOperation):
def to_init_code(self, field):
return ''
@ -17,6 +18,7 @@ class Flatten(PyTorchOperation):
assert len(inputs) == 1
return f'{output} = {inputs[0]}.view({inputs[0]}.size(0), -1)'
class ToDevice(PyTorchOperation):
def to_init_code(self, field):
return ''

View file

@ -1,8 +1,12 @@
import abc
from typing import List
from ..graph import Model
from ..mutator import Mutator
class BaseStrategy(abc.ABC):
@abc.abstractmethod
def run(self, base_model: 'Model', applied_mutators: List['Mutator']) -> None:
def run(self, base_model: Model, applied_mutators: List[Mutator]) -> None:
pass
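BaseStrategy.run previously annotated its parameters with string literals ('Model', 'Mutator') without importing the classes, so pylint could not resolve the names; the diff imports the real types and drops the quotes. Sketch with stand-in classes:

import abc
from typing import List


class Model: ...  # stand-in for nni.retiarii.graph.Model


class Mutator: ...  # stand-in for nni.retiarii.mutator.Mutator


class BaseStrategy(abc.ABC):
    @abc.abstractmethod
    def run(self, base_model: Model, applied_mutators: List[Mutator]) -> None:
        pass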

View file

@ -1,16 +1,13 @@
import json
import logging
import random
import os
from .. import Model, submit_models, wait_models
from .. import Sampler
from .. import Sampler, submit_models, wait_models
from .strategy import BaseStrategy
from ...algorithms.hpo.hyperopt_tuner.hyperopt_tuner import HyperoptTuner
_logger = logging.getLogger(__name__)
class TPESampler(Sampler):
def __init__(self, optimize_mode='minimize'):
self.tpe_tuner = HyperoptTuner('tpe', optimize_mode)
@ -37,6 +34,7 @@ class TPESampler(Sampler):
self.index += 1
return chosen
class TPEStrategy(BaseStrategy):
def __init__(self):
self.tpe_sampler = TPESampler()
@ -55,7 +53,7 @@ class TPEStrategy(BaseStrategy):
while True:
model = base_model
_logger.info('apply mutators...')
_logger.info('mutators: {}'.format(applied_mutators))
_logger.info('mutators: %s', str(applied_mutators))
self.tpe_sampler.generate_samples(self.model_id)
for mutator in applied_mutators:
_logger.info('mutate model...')
@ -66,6 +64,6 @@ class TPEStrategy(BaseStrategy):
wait_models(model)
self.tpe_sampler.receive_result(self.model_id, model.metric)
self.model_id += 1
_logger.info('Strategy says:', model.metric)
except Exception as e:
_logger.info('Strategy says: %s', model.metric)
except Exception:
_logger.error(logging.exception('message'))
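Two genuine bugs fixed here. First, _logger.info('Strategy says:', model.metric) passes an argument without a %s placeholder (pylint E1205, logging-too-many-args), which produces a '--- Logging error ---' traceback at emit time instead of the message. Second, except Exception as e bound a name that was never used (flake8 F841). Sketch:

import logging

logging.basicConfig(level=logging.INFO)
_logger = logging.getLogger(__name__)
metric = 0.93  # hypothetical value

# Before (broken): _logger.info('Strategy says:', metric)
_logger.info('Strategy says: %s', metric)

try:
    raise RuntimeError('boom')
except Exception:  # 'as e' removed: the binding was never used (F841)
    _logger.exception('message')

Note the diff keeps _logger.error(logging.exception('message')); logging.exception already logs and returns None, so the sketch calls _logger.exception directly instead.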

View file

@ -1,6 +1,5 @@
import abc
import inspect
from typing import *
from typing import Any
class BaseTrainer(abc.ABC):

View file

@ -1,5 +1,4 @@
import abc
from typing import *
from typing import Any, List, Dict, Tuple
import numpy as np
import torch
@ -42,6 +41,7 @@ def get_default_transform(dataset: str) -> Any:
# unsupported dataset, return None
return None
@register_trainer()
class PyTorchImageClassificationTrainer(BaseTrainer):
"""
@ -94,7 +94,7 @@ class PyTorchImageClassificationTrainer(BaseTrainer):
self._dataloader = DataLoader(
self._dataset, **(dataloader_kwargs or {}))
def _accuracy(self, input, target):
def _accuracy(self, input, target): # pylint: disable=redefined-builtin
_, predict = torch.max(input.data, 1)
correct = predict.eq(target.data).cpu().sum().item()
return correct / input.size(0)
@ -176,7 +176,7 @@ class PyTorchMultiModelTrainer(BaseTrainer):
dataloader = DataLoader(dataset, **(dataloader_kwargs or {}))
self._datasets.append(dataset)
self._dataloaders.append(dataloader)
if m['use_output']:
optimizer_cls = m['optimizer_cls']
optimizer_kwargs = m['optimizer_kwargs']
@ -186,7 +186,7 @@ class PyTorchMultiModelTrainer(BaseTrainer):
name_prefix = '_'.join(name.split('_')[:2])
if m_header == name_prefix:
one_model_params.append(param)
optimizer = getattr(torch.optim, optimizer_cls)(one_model_params, **(optimizer_kwargs or {}))
self._optimizers.append(optimizer)
@ -206,7 +206,7 @@ class PyTorchMultiModelTrainer(BaseTrainer):
x, y = self.training_step_before_model(batch, batch_idx, f'cuda:{idx}')
xs.append(x)
ys.append(y)
y_hats = self.multi_model(*xs)
if len(ys) != len(xs):
raise ValueError('len(ys) should be equal to len(xs)')
@ -230,13 +230,12 @@ class PyTorchMultiModelTrainer(BaseTrainer):
if self.max_steps and batch_idx >= self.max_steps:
return
def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> Dict[str, Any]:
x, y = self.training_step_before_model(batch, batch_idx)
y_hat = self.model(x)
return self.training_step_after_model(x, y, y_hat)
def training_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int, device = None):
def training_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int, device=None):
x, y = batch
if device:
x, y = x.cuda(torch.device(device)), y.cuda(torch.device(device))
@ -259,4 +258,4 @@ class PyTorchMultiModelTrainer(BaseTrainer):
def validation_step_after_model(self, x, y, y_hat):
acc = self._accuracy(y_hat, y)
return {'val_acc': acc}
return {'val_acc': acc}
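_accuracy(self, input, target) shadows the builtin input(), but renaming the parameter would change an established signature, so the diff silences pylint inline on the def line instead of renaming. A standalone sketch (requires torch):

import torch


def accuracy(input, target):  # pylint: disable=redefined-builtin
    # Top-1 accuracy: fraction of rows whose argmax matches the target.
    _, predict = torch.max(input, 1)
    correct = predict.eq(target).cpu().sum().item()
    return correct / input.size(0)


print(accuracy(torch.eye(3), torch.tensor([0, 1, 2])))  # 1.0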

View file

@ -6,7 +6,6 @@ import logging
import torch
import torch.nn as nn
from nni.nas.pytorch.mutables import LayerChoice
from ..interface import BaseOneShotTrainer
from .utils import AverageMeterGroup, replace_layer_choice, replace_input_choice

View file

@ -86,8 +86,8 @@ class ReinforceController(nn.Module):
self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]),
requires_grad=False) # pylint: disable=not-callable
self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]), # pylint: disable=not-callable
requires_grad=False)
assert entropy_reduction in ['sum', 'mean'], 'Entropy reduction must be one of sum and mean.'
self.entropy_reduction = torch.sum if entropy_reduction == 'sum' else torch.mean
self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='none')

View file

@ -16,7 +16,7 @@ _logger = logging.getLogger(__name__)
def _get_mask(sampled, total):
multihot = [i == sampled or (isinstance(sampled, list) and i in sampled) for i in range(total)]
return torch.tensor(multihot, dtype=torch.bool)
return torch.tensor(multihot, dtype=torch.bool) # pylint: disable=not-callable
class PathSamplingLayerChoice(nn.Module):
@ -44,9 +44,9 @@ class PathSamplingLayerChoice(nn.Module):
def forward(self, *args, **kwargs):
assert self.sampled is not None, 'At least one path needs to be sampled before fprop.'
if isinstance(self.sampled, list):
return sum([getattr(self, self.op_names[i])(*args, **kwargs) for i in self.sampled])
return sum([getattr(self, self.op_names[i])(*args, **kwargs) for i in self.sampled]) # pylint: disable=not-an-iterable
else:
return getattr(self, self.op_names[self.sampled])(*args, **kwargs)
return getattr(self, self.op_names[self.sampled])(*args, **kwargs) # pylint: disable=invalid-sequence-index
def __len__(self):
return len(self.op_names)
@ -76,7 +76,7 @@ class PathSamplingInputChoice(nn.Module):
def forward(self, input_tensors):
if isinstance(self.sampled, list):
return sum([input_tensors[t] for t in self.sampled])
return sum([input_tensors[t] for t in self.sampled]) # pylint: disable=not-an-iterable
else:
return input_tensors[self.sampled]
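The pylint comments in this file suppress false positives rather than change behavior: static analysis cannot see that torch.tensor is callable (not-callable) or that self.sampled may be a list (not-an-iterable, invalid-sequence-index), so the disables are pinned to the offending lines. Minimal sketch of the not-callable case:

import torch

multihot = [True, False, True]
# The call is valid; pylint's model of torch just cannot prove it.
mask = torch.tensor(multihot, dtype=torch.bool)  # pylint: disable=not-callable
print(mask)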

View file

@ -123,13 +123,13 @@ class AverageMeter:
return fmtstr.format(**self.__dict__)
def _replace_module_with_type(root_module, init_fn, type, modules):
def _replace_module_with_type(root_module, init_fn, type_name, modules):
if modules is None:
modules = []
def apply(m):
for name, child in m.named_children():
if isinstance(child, type):
if isinstance(child, type_name):
setattr(m, name, init_fn(child))
modules.append((child.key, getattr(m, name)))
else:

View file

@ -1,19 +1,24 @@
from collections import defaultdict
import inspect
from collections import defaultdict
from typing import Any
def import_(target: str, allow_none: bool = False) -> 'Any':
def import_(target: str, allow_none: bool = False) -> Any:
if target is None:
return None
path, identifier = target.rsplit('.', 1)
module = __import__(path, globals(), locals(), [identifier])
return getattr(module, identifier)
_records = {}
def get_records():
global _records
return _records
def add_record(key, value):
"""
"""
@ -22,6 +27,7 @@ def add_record(key, value):
assert key not in _records, '{} already in _records'.format(key)
_records[key] = value
def _register_module(original_class):
orig_init = original_class.__init__
argname_list = list(inspect.signature(original_class).parameters.keys())
@ -31,14 +37,15 @@ def _register_module(original_class):
full_args = {}
full_args.update(kws)
for i, arg in enumerate(args):
full_args[argname_list[i]] = args[i]
full_args[argname_list[i]] = arg
add_record(id(self), full_args)
orig_init(self, *args, **kws) # Call the original __init__
orig_init(self, *args, **kws) # Call the original __init__
original_class.__init__ = __init__ # Set the class' __init__ to the new one
original_class.__init__ = __init__ # Set the class' __init__ to the new one
return original_class
def register_module():
"""
Register a module.
@ -68,14 +75,15 @@ def _register_trainer(original_class):
if isinstance(args[i], Module):
# ignore the base model object
continue
full_args[argname_list[i]] = args[i]
full_args[argname_list[i]] = arg
add_record(id(self), {'modulename': full_class_name, 'args': full_args})
orig_init(self, *args, **kws) # Call the original __init__
orig_init(self, *args, **kws) # Call the original __init__
original_class.__init__ = __init__ # Set the class' __init__ to the new one
original_class.__init__ = __init__ # Set the class' __init__ to the new one
return original_class
def register_trainer():
def _register(cls):
m = _register_trainer(
@ -84,8 +92,10 @@ def register_trainer():
return _register
_last_uid = defaultdict(int)
def uid(namespace: str = 'default') -> int:
_last_uid[namespace] += 1
return _last_uid[namespace]

View file

@ -41,7 +41,7 @@ jobs:
python3 -m pip install --upgrade pygments
python3 -m pip install --upgrade torch>=1.7.0+cpu torchvision>=0.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
python3 -m pip install --upgrade tensorflow
python3 -m pip install --upgrade gym onnx peewee thop
python3 -m pip install --upgrade gym onnx peewee thop graphviz
python3 -m pip install sphinx==1.8.3 sphinx-argparse==0.2.5 sphinx-markdown-tables==0.0.9 sphinx-rtd-theme==0.4.2 sphinxcontrib-websupport==1.1.0 recommonmark==0.5.0 nbsphinx
sudo apt-get install swig -y
python3 -m pip install -e .[SMAC,BOHB]