Fix a few bugs in Retiarii and upgrade Dockerfile (#3713)

This commit is contained in:
Yuge Zhang 2021-06-03 14:48:56 +08:00 committed by GitHub
Parent 7eaf9c2476
Commit a284f71d11
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 85 additions and 40 deletions

View file

@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
FROM nvidia/cuda:9.2-cudnn7-runtime-ubuntu18.04
FROM nvidia/cuda:10.2-cudnn8-runtime-ubuntu18.04
ARG NNI_RELEASE
@@ -44,7 +44,7 @@ RUN ln -s python3 /usr/bin/python
RUN python3 -m pip install --upgrade pip==20.2.4 setuptools==50.3.2
# numpy 1.14.3 scipy 1.1.0
RUN python3 -m pip --no-cache-dir install numpy==1.14.3 scipy==1.1.0
RUN python3 -m pip --no-cache-dir install numpy==1.19.5 scipy==1.6.3
#
# TensorFlow
@@ -52,15 +52,14 @@ RUN python3 -m pip --no-cache-dir install numpy==1.14.3 scipy==1.1.0
RUN python3 -m pip --no-cache-dir install tensorflow==2.3.1
#
# Keras 2.1.6
# Keras
#
RUN python3 -m pip --no-cache-dir install Keras==2.1.6
RUN python3 -m pip --no-cache-dir install Keras==2.4.0
#
# PyTorch
#
RUN python3 -m pip --no-cache-dir install torch==1.6.0
RUN python3 -m pip install torchvision==0.7.0
RUN python3 -m pip --no-cache-dir install torch==1.7.1 torchvision==0.8.2 pytorch-lightning==1.3.3
#
# sklearn 0.24.1
@@ -70,7 +69,7 @@ RUN python3 -m pip --no-cache-dir install scikit-learn==0.24.1
#
# pandas==0.23.4 lightgbm==2.2.2
#
RUN python3 -m pip --no-cache-dir install pandas==0.23.4 lightgbm==2.2.2
RUN python3 -m pip --no-cache-dir install pandas==1.1 lightgbm==2.2.2
#
# Install NNI

View file

@@ -8,10 +8,13 @@ If you are experiencing issues with TorchScript, or the generated model code by
This will become the default execution engine in a future version of Retiarii.
Two steps are needed to enable this engine now.
Three steps are needed to enable this engine now.
1. Add ``@nni.retiarii.model_wrapper`` decorator outside the whole PyTorch model.
2. Add ``config.execution_engine = 'py'`` to ``RetiariiExeConfig``.
3. If you need to export top models, the formatter needs to be set to ``dict``. Exporting ``code`` won't work with this engine.
.. note:: You should always use ``super().__init__()`` instead of ``super(MyNetwork, self).__init__()`` in the PyTorch model, because the latter has issues with the model wrapper.
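For illustration, a minimal sketch of the three steps (the ``RetiariiExperiment`` instance ``exp`` and its dataloaders are assumed to be configured elsewhere; the network is a placeholder):

import nni.retiarii.nn.pytorch as nn
from nni.retiarii import model_wrapper
from nni.retiarii.experiment.pytorch import RetiariiExeConfig

# Step 1: decorate the whole PyTorch model (placeholder network for illustration).
@model_wrapper
class MyNetwork(nn.Module):
    def __init__(self):
        super().__init__()  # not super(MyNetwork, self).__init__(), see the note above
        self.fc = nn.Linear(32, 10)
    def forward(self, x):
        return self.fc(x)

# Step 2: switch the experiment to the pure-Python execution engine.
config = RetiariiExeConfig('local')
config.execution_engine = 'py'

# Step 3: export with the ``dict`` formatter; ``code`` does not work with this engine.
# top_models = exp.export_top_models(formatter='dict')  # ``exp`` is assumed to exist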
``@basic_unit`` and ``serializer``
----------------------------------

View file

@@ -4,22 +4,23 @@ import nni.retiarii.nn.pytorch as nn
import nni.retiarii.strategy as strategy
import nni.retiarii.evaluator.pytorch.lightning as pl
import torch.nn.functional as F
from nni.retiarii import serialize
from nni.retiarii import serialize, model_wrapper
from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment, debug_mutated_model
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
# uncomment this for python execution engine
# @model_wrapper
class Net(nn.Module):
def __init__(self, hidden_size):
super(Net, self).__init__()
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5, 1)
self.conv2 = nn.Conv2d(20, 50, 5, 1)
self.fc1 = nn.LayerChoice([
nn.Linear(4*4*50, hidden_size),
nn.Linear(4*4*50, hidden_size, bias=False)
])
], label='fc1_choice')
self.fc2 = nn.Linear(hidden_size, 10)
def forward(self, x):
@@ -55,8 +56,13 @@ if __name__ == '__main__':
exp_config.trial_concurrency = 2
exp_config.max_trial_number = 2
exp_config.training_service.use_active_gpu = False
export_formatter = 'code'
# uncomment this for python execution engine
# exp_config.execution_engine = 'py'
# export_formatter = 'dict'
exp.run(exp_config, 8081 + random.randint(0, 100))
print('Final model:')
for model_code in exp.export_top_models():
for model_code in exp.export_top_models(formatter=export_formatter):
print(model_code)

View file

@@ -162,6 +162,8 @@ import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import nni.retiarii.nn.pytorch
{}
{}

View file

@@ -30,7 +30,7 @@ class PythonGraphData:
class PurePythonExecutionEngine(BaseExecutionEngine):
@classmethod
def pack_model_data(cls, model: Model) -> Any:
mutation = {mut.mutator.label: _unpack_if_only_one(mut.samples) for mut in model.history}
mutation = get_mutation_dict(model)
graph_data = PythonGraphData(get_importable_name(model.python_class, relocate_module=True),
model.python_init_params, mutation, model.evaluator)
return graph_data
@@ -51,3 +51,7 @@ def _unpack_if_only_one(ele: List[Any]):
if len(ele) == 1:
return ele[0]
return ele
def get_mutation_dict(model: Model):
return {mut.mutator.label: _unpack_if_only_one(mut.samples) for mut in model.history}
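For context, a hedged usage sketch of the new helper (``model`` stands for any mutated ``Model`` with a recorded history; the printed label and value below are hypothetical, not from the source):

from nni.retiarii.execution.python import get_mutation_dict

mutation = get_mutation_dict(model)
# keys are mutator labels, values are the sampled choices;
# single-element sample lists are unpacked by _unpack_if_only_one,
# e.g. {'fc1_choice': 1}  (hypothetical label and value)
print(mutation)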

View file

@@ -29,6 +29,7 @@ from nni.tools.nnictl.command_utils import kill_command
from ..codegen import model_to_pytorch_script
from ..converter import convert_to_graph
from ..execution import list_models, set_execution_engine
from ..execution.python import get_mutation_dict
from ..graph import Model, Evaluator
from ..integration import RetiariiAdvisor
from ..mutator import Mutator
@@ -317,7 +318,7 @@ class RetiariiExperiment(Experiment):
"""
Export several top performing models.
For one-shot algorithms, only top-1 is supported. For others, ``optimize_mode`` asnd ``formater`` is
For one-shot algorithms, only top-1 is supported. For others, ``optimize_mode`` and ``formatter`` are
available for customization.
top_k : int
@@ -326,8 +327,12 @@ class RetiariiExperiment(Experiment):
``maximize`` or ``minimize``. Not supported by one-shot algorithms.
``optimize_mode`` is likely to be removed and defined in strategy in future.
formatter : str
Only model code is supported for now. Not supported by one-shot algorithms.
Support ``code`` and ``dict``. Not supported by one-shot algorithms.
If ``code``, the python code of model will be returned.
If ``dict``, the mutation history will be returned.
"""
if formatter == 'code':
assert self.config.execution_engine != 'py', 'You should use `dict` formatter when using Python execution engine.'
if isinstance(self.trainer, BaseOneShotTrainer):
assert top_k == 1, 'Only support top_k is 1 for now.'
return self.trainer.export()
@@ -335,9 +340,11 @@ class RetiariiExperiment(Experiment):
all_models = filter(lambda m: m.metric is not None, list_models())
assert optimize_mode in ['maximize', 'minimize']
all_models = sorted(all_models, key=lambda m: m.metric, reverse=optimize_mode == 'maximize')
assert formatter == 'code', 'Export formatter other than "code" is not supported yet.'
assert formatter in ['code', 'dict'], 'Export formatter other than "code" and "dict" is not supported yet.'
if formatter == 'code':
return [model_to_pytorch_script(model) for model in all_models[:top_k]]
elif formatter == 'dict':
return [get_mutation_dict(model) for model in all_models[:top_k]]
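A minimal usage sketch of the two formatters (assuming a finished ``RetiariiExperiment`` named ``exp``):

# Graph-based engine: export the generated PyTorch code of the best model.
for model_code in exp.export_top_models(top_k=1, optimize_mode='maximize', formatter='code'):
    print(model_code)

# Pure-Python engine: only the mutation history (a dict) can be exported.
for mutation in exp.export_top_models(formatter='dict'):
    print(mutation)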
def retrain_model(self, model):
"""

View file

@@ -1,6 +1,7 @@
# This file might cause an import error for those who haven't installed RL-related dependencies
import logging
from multiprocessing.pool import ThreadPool
import gym
import numpy as np
@@ -9,6 +10,7 @@ import torch.nn as nn
from gym import spaces
from tianshou.data import to_torch
from tianshou.env.worker import EnvWorker
from .utils import get_targeted_model
from ..graph import ModelStatus
@@ -18,6 +20,41 @@ from ..execution import submit_models, wait_models
_logger = logging.getLogger(__name__)
class MultiThreadEnvWorker(EnvWorker):
def __init__(self, env_fn):
self.env = env_fn()
self.pool = ThreadPool(processes=1)
super().__init__(env_fn)
def __getattr__(self, key):
return getattr(self.env, key)
def reset(self):
return self.env.reset()
@staticmethod
def wait(*args, **kwargs):
raise NotImplementedError('Async collect is not supported yet.')
def send_action(self, action) -> None:
# self.result is actually a handle
self.result = self.pool.apply_async(self.env.step, (action,))
def get_result(self):
return self.result.get()
def seed(self, seed):
super().seed(seed)
return self.env.seed(seed)
def render(self, **kwargs):
return self.env.render(**kwargs)
def close_env(self) -> None:
self.pool.terminate()
return self.env.close()
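A hedged sketch of how this worker is meant to be plugged into a tianshou vectorized environment (it mirrors the strategy change further below; ``env_fn`` and ``concurrency`` are assumed to be defined):

from tianshou.env import BaseVectorEnv

# Each environment runs step() in its own single-thread pool, so an env that
# blocks on submit_models/wait_models does not stall the other workers.
env = BaseVectorEnv([env_fn for _ in range(concurrency)], MultiThreadEnvWorker)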
class ModelEvaluationEnv(gym.Env):
def __init__(self, base_model, mutators, search_space):
self.base_model = base_model
@@ -107,7 +144,7 @@ class Actor(nn.Module):
# to take care of choices with different number of options
mask = torch.arange(self.action_dim).expand(len(out), self.action_dim) >= obs['action_dim'].unsqueeze(1)
out[mask.to(out.device)] = float('-inf')
return nn.functional.softmax(out), kwargs.get('state', None)
return nn.functional.softmax(out, dim=-1), kwargs.get('state', None)
class Critic(nn.Module):

View file

@@ -8,10 +8,10 @@ from ..execution import query_available_resources
try:
has_tianshou = True
import torch
from tianshou.data import AsyncCollector, Collector, VectorReplayBuffer
from tianshou.env import SubprocVectorEnv
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import BaseVectorEnv
from tianshou.policy import BasePolicy, PPOPolicy # pylint: disable=unused-import
from ._rl_impl import ModelEvaluationEnv, Preprocessor, Actor, Critic
from ._rl_impl import ModelEvaluationEnv, MultiThreadEnvWorker, Preprocessor, Actor, Critic
except ImportError:
has_tianshou = False
@@ -25,8 +25,6 @@ class PolicyBasedRL(BaseStrategy):
This is a wrapper of algorithms provided in tianshou (PPO by default),
and can be easily customized with other algorithms that inherit ``BasePolicy`` (e.g., REINFORCE [1]_).
Note that RL algorithms are known to have issues on Windows and MacOS. They will be supported in future.
Parameters
----------
max_collect : int
@@ -36,12 +34,6 @@ class PolicyBasedRL(BaseStrategy):
After each collect, trainer will sample batch from replay buffer and do the update. Default: 20.
policy_fn : function
Takes ``ModelEvaluationEnv`` as input and return a policy. See ``_default_policy_fn`` for an example.
asynchronous : bool
If true, in each step, collector won't wait for all the envs to complete.
This should generally not affect the result, but might affect the efficiency. Note that a slightly more trials
than expected might be collected if this is enabled.
If asynchronous is false, collector will wait for all parallel environments to complete in each step.
See ``tianshou.data.AsyncCollector`` for more details.
References
----------
@@ -51,7 +43,7 @@ class PolicyBasedRL(BaseStrategy):
"""
def __init__(self, max_collect: int = 100, trial_per_collect = 20,
policy_fn: Optional[Callable[['ModelEvaluationEnv'], 'BasePolicy']] = None, asynchronous: bool = True):
policy_fn: Optional[Callable[['ModelEvaluationEnv'], 'BasePolicy']] = None):
if not has_tianshou:
raise ImportError('`tianshou` is required to run RL-based strategy. '
'Please use "pip install tianshou" to install it beforehand.')
@@ -59,7 +51,6 @@ class PolicyBasedRL(BaseStrategy):
self.policy_fn = policy_fn or self._default_policy_fn
self.max_collect = max_collect
self.trial_per_collect = trial_per_collect
self.asynchronous = asynchronous
@staticmethod
def _default_policy_fn(env):
@@ -77,13 +68,8 @@ class PolicyBasedRL(BaseStrategy):
env_fn = lambda: ModelEvaluationEnv(base_model, applied_mutators, search_space)
policy = self.policy_fn(env_fn())
if self.asynchronous:
# wait for half of the env complete in each step
env = SubprocVectorEnv([env_fn for _ in range(concurrency)], wait_num=int(concurrency * 0.5))
collector = AsyncCollector(policy, env, VectorReplayBuffer(20000, len(env)))
else:
env = SubprocVectorEnv([env_fn for _ in range(concurrency)])
collector = Collector(policy, env, VectorReplayBuffer(20000, len(env)))
env = BaseVectorEnv([env_fn for _ in range(concurrency)], MultiThreadEnvWorker)
collector = Collector(policy, env, VectorReplayBuffer(20000, len(env)))
for cur_collect in range(1, self.max_collect + 1):
_logger.info('Collect [%d] Running...', cur_collect)
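A minimal usage sketch of the strategy after this change (the ``asynchronous`` argument is gone; parallelism now comes from the thread-based env workers shown above):

import nni.retiarii.strategy as strategy

rl_strategy = strategy.PolicyBasedRL(max_collect=100, trial_per_collect=20)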

View file

@@ -3,6 +3,8 @@ import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import nni.retiarii.nn.pytorch
import torch

View file

@@ -141,7 +141,6 @@ def test_evolution():
_reset_execution_engine()
@pytest.mark.skipif(sys.platform in ('win32', 'darwin'), reason='Does not run on Windows and MacOS')
def test_rl():
rl = strategy.PolicyBasedRL(max_collect=2, trial_per_collect=10)
engine = MockExecutionEngine(failure_prob=0.2)
@@ -150,7 +149,7 @@ def test_rl():
wait_models(*engine.models)
_reset_execution_engine()
rl = strategy.PolicyBasedRL(max_collect=2, trial_per_collect=10, asynchronous=False)
rl = strategy.PolicyBasedRL(max_collect=2, trial_per_collect=10)
engine = MockExecutionEngine(failure_prob=0.2)
_reset_execution_engine(engine)
rl.run(*_get_model_and_mutators())