[retiarii] fix experiment does not exit after done (#4916)

This commit is contained in:
QuanluZhang 2022-06-10 13:55:56 +08:00 committed by GitHub
Parent c299e57674
Commit 2bc984412c
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
3 changed files: 87 additions and 8 deletions

View file

@@ -46,8 +46,10 @@ class MsgDispatcherBase(Recoverable):
self._channel.connect()
self.default_command_queue = Queue()
self.assessor_command_queue = Queue()
self.default_worker = threading.Thread(target=self.command_queue_worker, args=(self.default_command_queue,))
self.assessor_worker = threading.Thread(target=self.command_queue_worker, args=(self.assessor_command_queue,))
# here daemon should be True, because their parent thread is configured as daemon to enable smooth exit of NAS experiment.
# if daemon is not set, these threads will block the daemon effect of their parent thread.
self.default_worker = threading.Thread(target=self.command_queue_worker, args=(self.default_command_queue,), daemon=True)
self.assessor_worker = threading.Thread(target=self.command_queue_worker, args=(self.assessor_command_queue,), daemon=True)
self.worker_exceptions = []
def run(self):

View file

@@ -0,0 +1,76 @@
import argparse
import os
import sys
import pytorch_lightning as pl
import pytest
from subprocess import Popen
from nni.retiarii import strategy
from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment
from .test_oneshot import _mnist_net
# Skip the whole module on pre-1.0 pytorch-lightning (incompatible APIs).
# NOTE: compare the parsed major version, not raw strings -- lexicographic
# string comparison misorders multi-digit components (e.g. '0.10' < '0.9').
_pl_major = int(pl.__version__.split('.', 1)[0])
pytestmark = pytest.mark.skipif(_pl_major < 1, reason='Incompatible APIs')
def test_multi_trial():
    """Smoke-test multi-trial NAS: every model-space flavour runs exactly one
    local trial and must export a top model as a plain dict."""
    evaluator_kwargs = {'max_epochs': 1}
    # Build all (model, evaluator) pairs up front, one per model-space flavour.
    to_test = [
        _mnist_net(space, evaluator_kwargs)
        for space in ('simple', 'simple_value_choice', 'value_choice',
                      'repeat', 'custom_op')
    ]
    for model, evaluator in to_test:
        experiment = RetiariiExperiment(model, evaluator, strategy=strategy.Random())
        config = RetiariiExeConfig('local')
        config.experiment_name = 'mnist_unittest'
        config.trial_concurrency = 1
        config.max_trial_number = 1
        config.training_service.use_active_gpu = False
        experiment.run(config, 8080)
        # The exported architecture is a mapping from choice label to value.
        assert isinstance(experiment.export_top_models()[0], dict)
        experiment.stop()
# Standalone script executed as a subprocess by test_exp_exit_without_stop.
# It deliberately does NOT call exp.stop(), so the child process is expected
# to terminate on its own after the experiment finishes — presumably a
# regression check for the daemon-thread exit fix; the exact string content
# is runtime behavior and must not be altered.
python_script = """
from nni.retiarii import strategy
from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment
from test_oneshot import _mnist_net
base_model, evaluator = _mnist_net('simple', {'max_epochs': 1})
search_strategy = strategy.Random()
exp = RetiariiExperiment(base_model, evaluator, strategy=search_strategy)
exp_config = RetiariiExeConfig('local')
exp_config.experiment_name = 'mnist_unittest'
exp_config.trial_concurrency = 1
exp_config.max_trial_number = 1
exp_config.training_service.use_active_gpu = False
exp.run(exp_config, 8080)
assert isinstance(exp.export_top_models()[0], dict)
"""
@pytest.mark.timeout(600)
def test_exp_exit_without_stop():
    """Run `python_script` in a subprocess and require it to exit cleanly on
    its own, without an explicit exp.stop().

    Fixes over the original: the child's return code is now asserted (a
    crashing subprocess used to pass silently because proc.wait()'s result was
    discarded), and the temporary script is removed in a finally block so it
    no longer leaks when the subprocess fails or the timeout fires.
    """
    script_name = 'tmp_multi_trial.py'
    with open(script_name, 'w') as f:
        f.write(python_script)
    try:
        proc = Popen([sys.executable, script_name])
        # Fail loudly if the child crashed instead of exiting cleanly.
        assert proc.wait() == 0
    finally:
        # Clean up the temp script even on failure.
        os.remove(script_name)
if __name__ == '__main__':
    # CLI entry point: run every test, or a single one selected by name.
    cli = argparse.ArgumentParser()
    cli.add_argument('--exp', type=str, default='all', metavar='E',
                     help='experiment to run, default = all')
    options = cli.parse_args()
    if options.exp != 'all':
        # Dispatch to the matching module-level test_<name> function.
        globals()[f'test_{options.exp}']()
    else:
        test_multi_trial()
        test_exp_exit_without_stop()

View file

@@ -7,6 +7,7 @@ from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import Dataset, RandomSampler
import nni
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import strategy, model_wrapper, basic_unit
from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment
@@ -216,13 +217,13 @@ def _mnist_net(type_, evaluator_kwargs):
raise ValueError(f'Unsupported type: {type_}')
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = MNIST('data/mnist', train=True, download=True, transform=transform)
train_dataset = nni.trace(MNIST)('data/mnist', train=True, download=True, transform=transform)
# Multi-GPU combined dataloader will break this subset sampler. Expected though.
train_random_sampler = RandomSampler(train_dataset, True, int(len(train_dataset) / 20))
train_loader = DataLoader(train_dataset, 64, sampler=train_random_sampler)
valid_dataset = MNIST('data/mnist', train=False, download=True, transform=transform)
valid_random_sampler = RandomSampler(valid_dataset, True, int(len(valid_dataset) / 20))
valid_loader = DataLoader(valid_dataset, 64, sampler=valid_random_sampler)
train_random_sampler = nni.trace(RandomSampler)(train_dataset, True, int(len(train_dataset) / 20))
train_loader = nni.trace(DataLoader)(train_dataset, 64, sampler=train_random_sampler)
valid_dataset = nni.trace(MNIST)('data/mnist', train=False, download=True, transform=transform)
valid_random_sampler = nni.trace(RandomSampler)(valid_dataset, True, int(len(valid_dataset) / 20))
valid_loader = nni.trace(DataLoader)(valid_dataset, 64, sampler=valid_random_sampler)
evaluator = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, **evaluator_kwargs)
return base_model, evaluator