[Retiarii] cross-graph optimization: device placement and input deduplication (#3202)

Zhenhua Han 2021-07-30 19:11:02 +08:00, committed by GitHub
Parent 6645bd332d
Commit f2f58dbb55
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
29 changed files: 1872 additions and 601 deletions

.gitignore (4 changes)

@ -10,6 +10,8 @@
/ts/nni_manager/exp_profile.json
/ts/nni_manager/metrics.json
/ts/nni_manager/trial_jobs.json
/test/ut/retiarii/_debug_graph_data.json
/test/ut/retiarii/out.tmp
# Logs
logs
@ -105,5 +107,3 @@ venv.bak/
.vscode
.vs
.history
generated/
test/ut/retiarii/_debug_graph_data.json

dependencies/recommended.txt (2 changes)

@ -8,7 +8,7 @@ torch == 1.9.0+cpu ; sys_platform != "darwin"
torch == 1.9.0 ; sys_platform == "darwin"
torchvision == 0.10.0+cpu ; sys_platform != "darwin"
torchvision == 0.10.0 ; sys_platform == "darwin"
pytorch-lightning >= 1.1.1
pytorch-lightning >= 1.2.8
onnx
peewee
graphviz

dependencies/recommended_gpu.txt (2 changes)

@ -5,7 +5,7 @@ tensorflow
keras == 2.4.3
torch == 1.9.0+cu111
torchvision == 0.10.0+cu111
pytorch-lightning >= 1.1.1
pytorch-lightning >= 1.2.8
onnx
peewee
graphviz

nni/common/device.py

@ -9,7 +9,9 @@ class GPUDevice:
status: Literal['idle', 'busy', 'unknown'] = 'idle'
def __eq__(self, o) -> bool:
return self.node_id == o.node_id and self.gpu_id == o.gpu_id
if isinstance(o, GPUDevice):
return self.node_id == o.node_id and self.gpu_id == o.gpu_id
return False
def __lt__(self, o) -> bool:
if self.node_id < o.node_id:
@ -23,7 +25,10 @@ class GPUDevice:
return "{Environment %s, GPU %d, Status %s}" % (self.node_id, self.gpu_id, self.status)
def __hash__(self) -> int:
return hash(self.node_id + '_' + self.gpu_id)
return hash(self.node_id + '_' + str(self.gpu_id))
def set_status(self, status):
self.status = status
def device_repr(self):
return f"cuda:{self.gpu_id}"

Просмотреть файл

@ -115,7 +115,7 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
node_code = node.operation.to_init_code(_remove_prefix(node.name, graph_name))
if node_code is not None:
if placement and node in placement and len(node_code) > 0:
node_codes.append(f"{node_code}.to('{placement[node].device}')")
node_codes.append(f"{node_code}.to('{placement[node].device_repr()}')")
else:
node_codes.append(node_code)
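
As an illustration of the changed line (names made up), the .to(...) suffix now comes from the placement's device_repr() method instead of a raw .device attribute:

# A minimal sketch of the codegen behavior above, with a stand-in placement object.
class _FakeDevice:
    def device_repr(self):
        return 'cuda:0'

node_code = "self.conv1 = torch.nn.Conv2d(3, 16, kernel_size=3)"
print(f"{node_code}.to('{_FakeDevice().device_repr()}')")
# self.conv1 = torch.nn.Conv2d(3, 16, kernel_size=3).to('cuda:0')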


nni/retiarii/evaluator/pytorch/cgo/accelerator.py (new file)

@ -0,0 +1,106 @@
from typing import Any, Union, Optional, List
import torch
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from pytorch_lightning.plugins import Plugin
from pytorch_lightning.plugins.environments import ClusterEnvironment
from ....serializer import serialize_cls
class BypassPlugin(TrainingTypePlugin):
""" Plugin that handles communication on a single device. """
def __init__(self, device: str):
super().__init__()
self.device: str = device
self.global_rank = 0
self.local_rank = 0
self.world_size = 1
def connect(self, model: torch.nn.Module) -> torch.nn.Module:
self._model = model
self.model_to_device()
return self.model
@property
def on_tpu(self) -> bool:
return False
@property
def on_gpu(self) -> bool:
return "cuda" in self.device and torch.cuda.is_available()
def reduce(self, tensor: Union[Any, torch.Tensor], *args: Any, **kwargs: Any) -> Union[Any, torch.Tensor]:
"""
Reduces a tensor from several distributed processes to one aggregated tensor.
As this plugin only operates with a single device, the reduction is simply the identity.
Args:
tensor: the tensor to sync and reduce
*args: ignored
**kwargs: ignored
Return:
the unmodified input as reduction is not needed for single process operation
"""
return tensor
def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
"""Perform a all_gather on all processes """
return tensor
@property
def root_device(self) -> torch.device:
return torch.device(self.device)
def model_to_device(self) -> None:
# bypass device placement from pytorch lightning
pass
def setup(self, model: torch.nn.Module) -> torch.nn.Module:
self.model_to_device()
return self.model
@property
def is_global_zero(self) -> bool:
return True
def barrier(self, *args, **kwargs) -> None:
pass
def broadcast(self, obj: object, src: int = 0) -> object:
return obj
def get_accelerator_connector(
num_processes: int = 1,
tpu_cores: Optional[Union[List[int], str, int]] = None,
distributed_backend: Optional[str] = None,
auto_select_gpus: bool = False,
gpus: Optional[Union[List[int], str, int]] = None,
num_nodes: int = 1,
sync_batchnorm: bool = False,
benchmark: bool = False,
replace_sampler_ddp: bool = True,
deterministic: bool = False,
precision: int = 32,
amp_backend: str = 'native',
amp_level: str = 'O2',
plugins: Optional[Union[List[Union[Plugin, ClusterEnvironment, str]], Plugin, ClusterEnvironment, str]] = None):
return AcceleratorConnector(
num_processes, tpu_cores, distributed_backend, auto_select_gpus, gpus, num_nodes, sync_batchnorm, benchmark,
replace_sampler_ddp, deterministic, precision, amp_backend, amp_level, plugins
)
@serialize_cls
class BypassAccelerator(Accelerator):
def __init__(self, precision_plugin=None, device="cpu"):
if precision_plugin is None:
precision_plugin = get_accelerator_connector().precision_plugin
# pylint: disable=abstract-class-instantiated
super().__init__(precision_plugin=precision_plugin, training_type_plugin=BypassPlugin(device))
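
A minimal sketch of using the bypass pieces directly, mirroring what Trainer(use_cgo=True) does internally (see trainer.py below): the accelerator keeps Lightning from moving the model, so the generated multi-model script controls placement with its own .to(...) calls.

import pytorch_lightning as pl
from nni.retiarii.evaluator.pytorch.cgo.accelerator import BypassAccelerator

# BypassPlugin.model_to_device() is a no-op, so Lightning never relocates the model.
trainer = pl.Trainer(accelerator=BypassAccelerator(device='cpu'), max_epochs=1)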

nni/retiarii/evaluator/pytorch/cgo/evaluator.py (new file)

@ -0,0 +1,222 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import warnings
from typing import Dict, List, Optional, Union
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
from torch.utils.data import DataLoader
import nni
from ..lightning import LightningModule, _AccuracyWithLogits, Lightning
from .trainer import Trainer
from ....serializer import serialize_cls
@serialize_cls
class _MultiModelSupervisedLearningModule(LightningModule):
def __init__(self, criterion: nn.Module, metrics: Dict[str, pl.metrics.Metric],
n_models: int = 0,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam):
super().__init__()
self.save_hyperparameters('criterion', 'optimizer', 'learning_rate', 'weight_decay')
self.criterion = criterion()
self.criterion_cls = criterion
self.optimizer = optimizer
self.metrics = nn.ModuleDict({name: cls() for name, cls in metrics.items()})
self.n_models = n_models
def forward(self, x):
y_hat = self.model(x)
return y_hat
def training_step(self, batch, batch_idx):
x, y = batch
multi_y_hat = self(x)
if isinstance(multi_y_hat, tuple):
assert len(multi_y_hat) == self.n_models
else:
assert self.n_models == 1
multi_y_hat = [multi_y_hat]
multi_loss = []
for idx, y_hat in enumerate(multi_y_hat):
loss = self.criterion(y_hat.to("cpu"), y.to("cpu"))
self.log(f'train_loss_{idx}', loss, prog_bar=True)
for name, metric in self.metrics.items():
self.log(f'train_{idx}_' + name, metric(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)
multi_loss.append(loss)
return sum(multi_loss)
def validation_step(self, batch, batch_idx):
x, y = batch
multi_y_hat = self(x)
if isinstance(multi_y_hat, tuple):
assert len(multi_y_hat) == self.n_models
else:
assert self.n_models == 1
multi_y_hat = [multi_y_hat]
for idx, y_hat in enumerate(multi_y_hat):
self.log(f'val_loss_{idx}', self.criterion(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)
for name, metric in self.metrics.items():
self.log(f'val_{idx}_' + name, metric(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)
def test_step(self, batch, batch_idx):
x, y = batch
multi_y_hat = self(x)
if isinstance(multi_y_hat, tuple):
assert len(multi_y_hat) == self.n_models
else:
assert self.n_models == 1
multi_y_hat = [multi_y_hat]
for idx, y_hat in enumerate(multi_y_hat):
self.log(f'test_loss_{idx}', self.criterion(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)
for name, metric in self.metrics.items():
self.log(f'test_{idx}_' + name, metric(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)
def configure_optimizers(self):
return self.optimizer(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)
def on_validation_epoch_end(self):
nni.report_intermediate_result(self._get_validation_metrics())
def teardown(self, stage):
if stage == 'fit':
nni.report_final_result(self._get_validation_metrics())
def _get_validation_metrics(self):
# TODO: split metric of multiple models?
if len(self.metrics) == 1:
metric_name = next(iter(self.metrics))
ret = []
for idx in range(self.n_models):
ret.append(self.trainer.callback_metrics[f'val_{idx}_' + metric_name].item())
return ret
else:
warnings.warn('Multiple metrics without "default" are not supported by the current framework.')
return {name: self.trainer.callback_metrics['val_' + name].item() for name in self.metrics}
class MultiModelSupervisedLearningModule(_MultiModelSupervisedLearningModule):
"""
Lightning Module of SupervisedLearning for Cross-Graph Optimization.
Users who need cross-graph optimization should use this module.
Parameters
----------
criterion : nn.Module
Class for criterion module (not an instance). default: ``nn.CrossEntropyLoss``
learning_rate : float
Learning rate. default: 0.001
weight_decay : float
L2 weight decay. default: 0
optimizer : Optimizer
Class for optimizer (not an instance). default: ``Adam``
"""
def __init__(self, criterion: nn.Module, metrics: Dict[str, pl.metrics.Metric],
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam):
super().__init__(criterion, metrics, learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)
@serialize_cls
class _ClassificationModule(MultiModelSupervisedLearningModule):
def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam):
super().__init__(criterion, {'acc': _AccuracyWithLogits},
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)
class Classification(Lightning):
"""
Trainer that is used for classification.
Parameters
----------
criterion : nn.Module
Class for criterion module (not an instance). default: ``nn.CrossEntropyLoss``
learning_rate : float
Learning rate. default: 0.001
weight_decay : float
L2 weight decay. default: 0
optimizer : Optimizer
Class for optimizer (not an instance). default: ``Adam``
train_dataloader : DataLoader
Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
If the ``lightning_module`` has a predefined train_dataloader method this will be skipped.
val_dataloaders : DataLoader or List of DataLoader
Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
trainer_kwargs : dict
Optional keyword arguments passed to trainer. See
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/trainer.html>`__ for details.
"""
def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam,
train_dataloader: Optional[DataLoader] = None,
val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
**trainer_kwargs):
module = _ClassificationModule(criterion=criterion, learning_rate=learning_rate,
weight_decay=weight_decay, optimizer=optimizer)
super().__init__(module, Trainer(use_cgo=True, **trainer_kwargs),
train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
@serialize_cls
class _RegressionModule(MultiModelSupervisedLearningModule):
def __init__(self, criterion: nn.Module = nn.MSELoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam):
super().__init__(criterion, {'mse': pl.metrics.MeanSquaredError},
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)
class Regression(Lightning):
"""
Trainer that is used for regression.
Parameters
----------
criterion : nn.Module
Class for criterion module (not an instance). default: ``nn.MSELoss``
learning_rate : float
Learning rate. default: 0.001
weight_decay : float
L2 weight decay. default: 0
optimizer : Optimizer
Class for optimizer (not an instance). default: ``Adam``
train_dataloader : DataLoader
Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
If the ``lightning_module`` has a predefined train_dataloader method this will be skipped.
val_dataloaders : DataLoader or List of DataLoader
Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
trainer_kwargs : dict
Optional keyword arguments passed to trainer. See
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/trainer.html>`__ for details.
"""
def __init__(self, criterion: nn.Module = nn.MSELoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam,
train_dataloader: Optional[DataLoader] = None,
val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
**trainer_kwargs):
module = _RegressionModule(criterion=criterion, learning_rate=learning_rate,
weight_decay=weight_decay, optimizer=optimizer)
super().__init__(module, Trainer(use_cgo=True, **trainer_kwargs),
train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
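
A usage sketch under stated assumptions (the MNIST setup is illustrative; Classification here is the CGO variant defined above, which internally creates Trainer(use_cgo=True)):

from torchvision import transforms
from torchvision.datasets import MNIST
from nni.retiarii.evaluator.pytorch.lightning import DataLoader  # serialized wrapper, not torch.utils.data
from nni.retiarii.evaluator.pytorch.cgo.evaluator import Classification

transform = transforms.ToTensor()
train_set = MNIST('data/mnist', train=True, download=True, transform=transform)
test_set = MNIST('data/mnist', train=False, download=True, transform=transform)

evaluator = Classification(
    train_dataloader=DataLoader(train_set, batch_size=100),
    val_dataloaders=DataLoader(test_set, batch_size=100),
    max_epochs=1)  # extra kwargs are forwarded to Trainer(use_cgo=True, ...)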

nni/retiarii/evaluator/pytorch/cgo/trainer.py (new file)

@ -0,0 +1,31 @@
import pytorch_lightning as pl
from ....serializer import serialize_cls
from .accelerator import BypassAccelerator
@serialize_cls
class Trainer(pl.Trainer):
"""
Trainer for cross-graph optimization.
Parameters
----------
use_cgo : bool
Whether cross-graph optimization (CGO) is used.
If it is True, CGO will manage device placement.
Any device placement from pytorch lightning will be bypassed.
default: False
trainer_kwargs : dict
Optional keyword arguments passed to trainer. See
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/trainer.html>`__ for details.
"""
def __init__(self, use_cgo=False, **trainer_kwargs):
if use_cgo:
if "accelerator" in trainer_kwargs:
raise ValueError("accelerator should not be set when cross-graph optimization is enabled.")
trainer_kwargs['accelerator'] = BypassAccelerator(device='cpu')
super().__init__(**trainer_kwargs)
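
For example (a sketch), enabling CGO installs the bypass accelerator, and passing accelerator yourself is rejected:

from nni.retiarii.evaluator.pytorch.cgo.trainer import Trainer

trainer = Trainer(use_cgo=True, max_epochs=1)  # accelerator becomes BypassAccelerator(device='cpu')

try:
    Trainer(use_cgo=True, accelerator='ddp')   # conflicts with CGO-managed placement
except ValueError as err:
    print(err)  # accelerator should not be set when cross-graph optimization is enabled.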

nni/retiarii/evaluator/pytorch/lightning.py

@ -12,6 +12,12 @@ import torch.optim as optim
from torch.utils.data import DataLoader
import nni
try:
import nni.retiarii.evaluator.pytorch.cgo.trainer as cgo_trainer
cgo_import_failed = False
except ImportError:
cgo_import_failed = True
from ...graph import Evaluator
from ...serializer import serialize_cls
@ -36,7 +42,6 @@ class LightningModule(pl.LightningModule):
Trainer = serialize_cls(pl.Trainer)
DataLoader = serialize_cls(DataLoader)
class Lightning(Evaluator):
"""
Delegate the whole training to PyTorch Lightning.
@ -67,7 +72,11 @@ class Lightning(Evaluator):
train_dataloader: Optional[DataLoader] = None,
val_dataloaders: Union[DataLoader, List[DataLoader], None] = None):
assert isinstance(lightning_module, LightningModule), f'Lightning module must be an instance of {__name__}.LightningModule.'
assert isinstance(trainer, Trainer), f'Trainer must be imported from {__name__}.'
if cgo_import_failed:
assert isinstance(trainer, Trainer), f'Trainer must be imported from {__name__}'
else:
assert isinstance(trainer, Trainer) or isinstance(trainer, cgo_trainer.Trainer), \
f'Trainer must be imported from {__name__} or nni.retiarii.evaluator.pytorch.cgo.trainer'
assert _check_dataloader(train_dataloader), f'Wrong dataloader type. Try import DataLoader from {__name__}.'
assert _check_dataloader(val_dataloaders), f'Wrong dataloader type. Try import DataLoader from {__name__}.'
self.module = lightning_module
@ -91,7 +100,21 @@ class Lightning(Evaluator):
return self.fit(model_cls)
def __eq__(self, other):
return self.function == other.function and self.arguments == other.arguments
eq_func = False
eq_args = False
if other is None:
return False
if hasattr(self, "function") and hasattr(other, "function"):
eq_func = (self.function == other.function)
elif not (hasattr(self, "function") or hasattr(other, "function")):
eq_func = True
if hasattr(self, "arguments") and hasattr(other, "arguments"):
eq_args = (self.arguments == other.arguments)
elif not (hasattr(self, "arguments") or hasattr(other, "arguments")):
eq_args = True
return eq_func and eq_args
def fit(self, model):
"""

nni/retiarii/execution/cgo_engine.py

@ -2,14 +2,22 @@
# Licensed under the MIT license.
import logging
import os
import random
import string
import time
import threading
from typing import Iterable, List, Dict, Tuple
from nni.common.device import GPUDevice
from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo
from .. import codegen, utils
from ..graph import Model, ModelStatus, MetricData
from ..graph import Model, ModelStatus, MetricData, Node
from ..integration_api import send_trial, receive_trial_parameters, get_advisor
from .logical_optimizer.logical_plan import LogicalPlan, PhysicalDevice
from .logical_optimizer.logical_plan import LogicalPlan, AbstractLogicalNode
from .logical_optimizer.opt_dedup_input import DedupInputOptimizer
from ..evaluator.pytorch.lightning import Lightning
from ..evaluator.pytorch.cgo.evaluator import MultiModelSupervisedLearningModule, _MultiModelSupervisedLearningModule
from .base import BaseGraphData
@ -17,29 +25,93 @@ _logger = logging.getLogger(__name__)
class CGOExecutionEngine(AbstractExecutionEngine):
def __init__(self, devices=None, n_model_per_graph=4) -> None:
"""
The execution engine with Cross-Graph Optimization (CGO).
Only models using PyTorch Lightning and MultiModelSupervisedLearningModule as the evaluator can be optimized.
Otherwise, a model will be submitted independently without any cross-graph optimization.
Parameters
----------
devices : List[str] or List[GPUDevice]
Available devices for execution.
If a list of str is provided, it will build a list of GPUDevice on a server named ``single_server``
max_concurrency : int
The maximum number of trials to run concurrently.
batch_waiting_time : int
Seconds to wait for each batch of trial submission.
The trials within one batch could apply cross-graph optimization.
"""
def __init__(self, devices: List[GPUDevice] = None,
max_concurrency: int = None,
batch_waiting_time: int = 60,
) -> None:
self._listeners: List[AbstractGraphListener] = []
self._running_models: Dict[int, Model] = dict()
self.logical_plan_counter = 0
self.n_model_per_graph = n_model_per_graph
self.available_devices: List[GPUDevice] = []
self.max_concurrency: int = max_concurrency
for device in devices:
self.available_devices.append(device)
self.all_devices = self.available_devices.copy()
self._batch_waiting_time = batch_waiting_time # seconds to wait for all models in a batch to do cross-graph optimization
self._optimizers = [DedupInputOptimizer()]
self._original_models = {}
self._original_model_to_multi_model = {}
self.devices = [] if devices is None else devices
self._trial_to_original_models = {}
self._trial_used_devices: Dict[int, List[GPUDevice]] = {}
self._history: List[Model] = []
self._queuing_jobs: List[Model] = []
self._queue_lock = threading.Lock()
# register advisor callbacks
advisor = get_advisor()
advisor.send_trial_callback = self._send_trial_callback
advisor.request_trial_jobs_callback = self._request_trial_jobs_callback
# advisor.send_trial_callback = self._send_trial_callback
# advisor.request_trial_jobs_callback = self._request_trial_jobs_callback
advisor.trial_end_callback = self._trial_end_callback
advisor.intermediate_metric_callback = self._intermediate_metric_callback
advisor.final_metric_callback = self._final_metric_callback
self._stopped = False
self._consumer_thread = threading.Thread(target=self._consume_queue)
self._consumer_thread.start()
def join(self):
self._stopped = True
self._consumer_thread.join()
def add_optimizer(self, opt):
self._optimizers.append(opt)
def submit_models(self, *models: List[Model]) -> None:
curr_time = time.time()
_logger.info('%d models are submitted', len(models))
self._queue_lock.acquire()
self._queuing_jobs.extend([(curr_time, _) for _ in models])
self._queue_lock.release()
def _consume_queue(self):
# a thread that monitors self._queuing_jobs and consumes them in batches
while not self._stopped:
if len(self._queuing_jobs) > 0:
curr_time = time.time()
self._queue_lock.acquire()
if (self.max_concurrency and len(self._queuing_jobs) >= self.max_concurrency):
self._submit_models_in_batch(*[_[1] for _ in self._queuing_jobs[:self.max_concurrency]])
self._queuing_jobs = self._queuing_jobs[self.max_concurrency:]
elif len(self.available_devices) <= len(self._queuing_jobs) or \
(curr_time - self._queuing_jobs[0][0] > self._batch_waiting_time):
self._submit_models_in_batch(*[_[1] for _ in self._queuing_jobs])
self._queuing_jobs = []
self._queue_lock.release()
time.sleep(1)
def _submit_models_in_batch(self, *models: List[Model]) -> None:
_logger.info('%d models are submitted in batch', len(models))
logical = self._build_logical(models)
for opt in self._optimizers:
@ -47,31 +119,51 @@ class CGOExecutionEngine(AbstractExecutionEngine):
phy_models_and_placements = self._assemble(logical)
for model, placement, grouped_models in phy_models_and_placements:
data = BaseGraphData(codegen.model_to_pytorch_script(model, placement=placement),
model.evaluator)
data = BaseGraphData(codegen.model_to_pytorch_script(model, placement=placement), model.evaluator)
trial_id = send_trial(data.dump())
# unique non-cpu devices used by the trial
self._trial_used_devices[trial_id] = list([_ for _ in set(placement.values()) if isinstance(_, GPUDevice)])
# currently, it is impossible for a search strategy to submit more models than the number of available devices
for used_device in self._trial_used_devices[trial_id]:
self.available_devices.remove(used_device) # used_device must be in self.available_devices
self._running_models[trial_id] = model
self._trial_to_original_models[trial_id] = []
for m in grouped_models:
self._original_models[m.model_id] = m
self._original_model_to_multi_model[m.model_id] = model
self._running_models[send_trial(data.dump())] = model
# for model in models:
# data = BaseGraphData(codegen.model_to_pytorch_script(model),
# model.config['trainer_module'], model.config['trainer_kwargs'])
# self._running_models[send_trial(data.dump())] = model
self._trial_to_original_models[trial_id].append(m.model_id)
self._history.append(m)
def list_models(self) -> Iterable[Model]:
raise NotImplementedError
return self._history
def _assemble(self, logical_plan: LogicalPlan) -> List[Tuple[Model, Dict[Node, GPUDevice], List[Model]]]:
# try to use the available_devices first so that it can be launched as early as possible
# if free devices are not enough to assemble all models in one trial, try all devices
if len(self.available_devices) > 0:
grouped_models: List[Dict[Model, GPUDevice]] = AssemblePolicy().group(logical_plan, self.available_devices)
if len(self.available_devices) == 0 or len(grouped_models) > 1:
grouped_models: List[Dict[Model, GPUDevice]] = AssemblePolicy().group(logical_plan, self.all_devices)
def _assemble(self, logical_plan: LogicalPlan) -> List[Tuple[Model, PhysicalDevice]]:
# unique_models = set()
# for node in logical_plan.graph.nodes:
# if node.graph.model not in unique_models:
# unique_models.add(node.graph.model)
# return [m for m in unique_models]
grouped_models: List[Dict[Model, PhysicalDevice]] = AssemblePolicy().group(logical_plan)
phy_models_and_placements = []
for multi_model in grouped_models:
model, model_placement = logical_plan.assemble(multi_model)
assert isinstance(model.evaluator, Lightning), \
"cross-graph optimization only supports PyTorch Lightning as the evaluator"
assert isinstance(model.evaluator.module, _MultiModelSupervisedLearningModule), \
"cross-graph optimization only supports MultiModelSupervisedLearningModule"
# replace the module with a new instance whose n_models is set
# n_models must be set in __init__, otherwise it cannot be captured by serialize_cls
new_module_init_params = model.evaluator.module._init_parameters.copy()
# MultiModelSupervisedLearningModule hides n_models of _MultiModelSupervisedLearningModule from users
new_module_init_params['n_models'] = len(multi_model)
new_module = _MultiModelSupervisedLearningModule(**new_module_init_params)
model.evaluator.module = new_module
phy_models_and_placements.append((model, model_placement, multi_model.keys()))
return phy_models_and_placements
@ -85,13 +177,14 @@ class CGOExecutionEngine(AbstractExecutionEngine):
def register_graph_listener(self, listener: AbstractGraphListener) -> None:
self._listeners.append(listener)
def _send_trial_callback(self, paramater: dict) -> None:
for listener in self._listeners:
listener.on_resource_used(0) # FIXME: find the real resource id
# def _send_trial_callback(self, paramater: dict) -> None:
# if len(self.available_devices) == 0:
# _logger.warning('There is no available devices, but trial is submitted.')
# _logger.debug('Resource used. Remaining: %d', len(self.available_devices))
def _request_trial_jobs_callback(self, num_trials: int) -> None:
for listener in self._listeners:
listener.on_resource_available([0] * num_trials) # FIXME: find the real resource id
# def _request_trial_jobs_callback(self, num_trials: int) -> None:
# self.resources += num_trials
# _logger.info('on_resource_available: %d', self.resources)
def _trial_end_callback(self, trial_id: int, success: bool) -> None:
model = self._running_models[trial_id]
@ -108,31 +201,40 @@ class CGOExecutionEngine(AbstractExecutionEngine):
original_model.status = ModelStatus.Failed
for listener in self._listeners:
listener.on_training_end(original_model, success)
self.available_devices.extend(self._trial_used_devices[trial_id])
self.available_devices = sorted(list(set(self.available_devices)))
del self._running_models[trial_id]
def _intermediate_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
# model = self._running_models[trial_id]
merged_metrics = dict(metrics)
merged_metrics = {}
for idx, _ in enumerate(metrics):
merged_metrics[self._trial_to_original_models[trial_id][idx]] = metrics[idx]
for model_id in merged_metrics:
int_model_id = int(model_id)
self._original_models[int_model_id].intermediate_metrics.append(merged_metrics[model_id])
# model.intermediate_metrics.append(metrics)
self._original_models[model_id].intermediate_metrics.append(merged_metrics[model_id])
for listener in self._listeners:
listener.on_intermediate_metric(self._original_models[int_model_id], merged_metrics[model_id])
listener.on_intermediate_metric(self._original_models[model_id], merged_metrics[model_id])
def _final_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
merged_metrics = dict(metrics)
for model_id in merged_metrics:
int_model_id = int(model_id)
self._original_models[int_model_id].intermediate_metrics.append(merged_metrics[model_id])
# model.intermediate_metrics.append(metrics)
for listener in self._listeners:
listener.on_metric(self._original_models[int_model_id], merged_metrics[model_id])
_logger.debug(metrics)
if isinstance(metrics, float):
self._listeners[0].on_metric(self._running_models[trial_id], metrics)
else:
merged_metrics = {}
for idx, _ in enumerate(metrics):
merged_metrics[self._trial_to_original_models[trial_id][idx]] = metrics[idx]
for model_id in merged_metrics:
self._original_models[model_id].metric = merged_metrics[model_id]
for listener in self._listeners:
listener.on_metric(self._original_models[model_id], merged_metrics[model_id])
def query_available_resource(self) -> List[WorkerInfo]:
raise NotImplementedError # move the method from listener to here?
# the _queuing_jobs need to use available_devices first
return len(self.available_devices) - len(self._queuing_jobs)
def budget_exhausted(self) -> bool:
raise NotImplementedError
advisor = get_advisor()
return advisor.stopping
@classmethod
def trial_execute_graph(cls) -> None:
@ -141,20 +243,86 @@ class CGOExecutionEngine(AbstractExecutionEngine):
"""
graph_data = BaseGraphData.load(receive_trial_parameters())
_logger.info('CGO_ENGINE trial parameters received')
with open('_generated_model.py', 'w') as f:
random_str = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(6))
file_name = f'_generated_model/{random_str}.py'
os.makedirs(os.path.dirname(file_name), exist_ok=True)
with open(file_name, 'w') as f:
f.write(graph_data.model_script)
# with open('_debug_graph_data.json', 'w') as f:
# json.dump(graph_data.dump(), f)
trainer_cls = utils.import_(graph_data.training_module)
model_cls = utils.import_(f"_generated_model.{graph_data.training_kwargs['model_cls']}")
trainer_instance = trainer_cls(model_cls(), graph_data.training_kwargs)
trainer_instance.fit()
trainer_instance = graph_data.evaluator
model_cls = utils.import_(f'_generated_model.{random_str}._model')
trainer_instance.fit(model_cls())
os.remove(file_name)
def _remap_cuda_device(group_model: Dict[Model, GPUDevice]):
used_devices = {}
for m in group_model:
if group_model[m].node_id not in used_devices:
used_devices[group_model[m].node_id] = {}
if isinstance(group_model[m], GPUDevice):
if group_model[m].gpu_id not in used_devices[group_model[m].node_id]:
n_used_gpu_in_server = len(used_devices[group_model[m].node_id])
used_devices[group_model[m].node_id][group_model[m].gpu_id] = n_used_gpu_in_server
group_model[m].gpu_id = used_devices[group_model[m].node_id][group_model[m].gpu_id]
return group_model
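
A worked sketch of the remapping (calling the module-private helper directly; strings stand in for Model objects and GPUDevice(node_id, gpu_id) construction is assumed): models placed on physical GPUs 1 and 3 of one server are renumbered to cuda:0 and cuda:1, since the generated script only sees the GPUs granted to the trial.

from nni.common.device import GPUDevice

group = {'model_a': GPUDevice('server0', 1), 'model_b': GPUDevice('server0', 3)}
remapped = _remap_cuda_device(group)
print([d.device_repr() for d in remapped.values()])  # ['cuda:0', 'cuda:1']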
class AssemblePolicy:
@staticmethod
def group(logical_plan):
def _is_related_node(model: Model, node: Node):
if isinstance(node, AbstractLogicalNode):
if model in node.related_models:
return True
else:
if model == node.graph.model:
return True
return False
@staticmethod
def _check_graph_connectivity(model: Model,
group_model: Dict[Model, GPUDevice],
logical_plan: LogicalPlan) -> bool:
for edge in logical_plan.logical_graph.edges:
if AssemblePolicy._is_related_node(model, edge.head) or \
AssemblePolicy._is_related_node(model, edge.tail):
for grouped_model in group_model:
if AssemblePolicy._is_related_node(grouped_model, edge.head) or \
AssemblePolicy._is_related_node(grouped_model, edge.tail):
return True
return False
@staticmethod
def _check_evaluator(new_model: Model, group_model: Dict[Model, GPUDevice]) -> bool:
if not (isinstance(new_model.evaluator, Lightning)
and isinstance(new_model.evaluator.module, MultiModelSupervisedLearningModule)):
return False
for m in group_model:
if not m.evaluator == new_model.evaluator:
return False
return True
@staticmethod
def group(logical_plan, available_devices):
# TODO: Packing multiple models in one GPU
# Currently, we only support one model per GPU
all_grouped_models = []
group_model = {}
assert(len(available_devices) > 0) # There should be at least 1 device, set in CGO_DEVICES
for idx, m in enumerate(logical_plan.models):
group_model[m] = PhysicalDevice('server', f'cuda:{idx}')
return [group_model]
# models in one group should
# (1) not use more GPUs than available_devices
# (2) be connected in the logical plan (independent models should be assembled in multiple groups)
# (3) use same MultiModelSupervisedLearningModule
if len(group_model) > 0 and \
(not AssemblePolicy._check_graph_connectivity(m, group_model, logical_plan) or
not AssemblePolicy._check_evaluator(m, group_model)):
all_grouped_models.append(_remap_cuda_device(group_model))
group_model = {}
group_model[m] = available_devices[idx % len(available_devices)]
if len(group_model) == len(available_devices) or \
idx == len(logical_plan.models) - 1:
all_grouped_models.append(_remap_cuda_device(group_model))
group_model = {}
return all_grouped_models
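
Putting the engine together (a sketch; __init__ registers advisor callbacks, so this only works inside a running Retiarii experiment, which normally builds the engine from the config as shown in experiment/pytorch.py below):

from nni.common.device import GPUDevice
from nni.retiarii.execution.cgo_engine import CGOExecutionEngine

devices = [GPUDevice('single_server', i) for i in range(4)]  # illustrative
engine = CGOExecutionEngine(devices,
                            max_concurrency=2,      # maximum number of trials run concurrently
                            batch_waiting_time=10)  # seconds before a queued batch is flushed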

nni/retiarii/execution/logical_optimizer/logical_plan.py

@ -2,30 +2,30 @@
# Licensed under the MIT license.
import copy
from typing import Dict, Tuple, List, Any
from typing import Dict, Tuple, Any, Union
from nni.retiarii.utils import uid
from nni.common.device import GPUDevice
from ...graph import Cell, Edge, Graph, Model, Node
from ...operation import Operation, _IOPseudoOperation
class PhysicalDevice:
def __init__(self, server: str, device: str):
self.server = server
self.device = device
class CPUDevice:
def __init__(self, node_id):
self.node_id = node_id
self.device = 'cpu'
def __eq__(self, o) -> bool:
return self.server == o.server and self.device == o.device
def __hash__(self) -> int:
return hash(self.server + '_' + self.device)
def device_repr(self):
return "cpu"
class AbstractLogicalNode(Node):
def __init__(self, graph, node_id, name, operation, _internal=False):
super().__init__(graph, node_id, name, operation, _internal=_internal)
self.related_models = []
def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) -> Tuple[Node, PhysicalDevice]:
def assemble(self, multi_model_placement: Dict[Model, GPUDevice]) -> Tuple[Node, GPUDevice]:
raise NotImplementedError
def _fork_to(self, graph: Graph):
@ -40,8 +40,7 @@ class LogicalGraph(Graph):
nodes_dump = {}
for node in self.hidden_nodes:
if isinstance(node, OriginNode):
nodes_dump[f"{node.original_graph.model.model_id}_{node.name}"] = node._dump(
)
nodes_dump[f"{node.original_graph.model.model_id}_{node.name}"] = node._dump()
else:
nodes_dump[f"{node.graph.model.model_id}_{node.name}"] = node._dump()
@ -93,7 +92,7 @@ class OriginNode(AbstractLogicalNode):
self.original_graph = original_graph
self.original_node = original_node
def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) -> Tuple[Node, PhysicalDevice]:
def assemble(self, multi_model_placement: Dict[Model, GPUDevice]) -> Tuple[Node, GPUDevice]:
model_id = self.original_node.graph.model.model_id
new_node = Node(self.original_node.graph, self.original_node.id,
f"M_{model_id}_" +
@ -137,30 +136,32 @@ class LogicalPlan:
for edge in from_graph.edges:
new_head = id_to_new_node[edge.head.id]
new_tail = id_to_new_node[edge.tail.id]
Edge((new_head, edge.head_slot), (new_tail,
edge.tail_slot), _internal=True)._register()
Edge((new_head, edge.head_slot), (new_tail, edge.tail_slot), _internal=True)._register()
def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) \
-> Tuple[Model, Dict[Node, PhysicalDevice], List[Model]]:
phy_model = Model(_internal=True) # self.lp_model.fork()
def assemble(self, multi_model_placement: Dict[Model, GPUDevice]) \
-> Tuple[Model, Dict[Node, Union[GPUDevice, CPUDevice]]]:
phy_model = Model(_internal=True)
phy_graph = self.lp_model.root_graph._fork_to(phy_model)
# Add a flag to mark multi-model in graph json.
# Multi-model has a list of training configs in kwargs['model_kwargs']
if len(multi_model_placement) > 1:
phy_model.evaluator.kwargs['is_multi_model'] = True
phy_model.evaluator.kwargs['model_cls'] = phy_graph.name
phy_model.evaluator.kwargs['model_kwargs'] = []
# FIXME: allow user to specify
phy_model.evaluator.module = 'nni.retiarii.trainer.pytorch.PyTorchMultiModelTrainer'
phy_graph._rename_graph(phy_graph.name, "_model")
# merge sub-graphs
for model in multi_model_placement:
if phy_model.evaluator is None and model.evaluator is not None:
phy_model.evaluator = model.evaluator
for graph_name in model.graphs:
if graph_name != model._root_graph_name:
model.graphs[graph_name]._fork_to(
new_graph = model.graphs[graph_name]._fork_to(
phy_model, name_prefix=f'M_{model.model_id}_')
# prefix of M_ of hidden_nodes name in non-root graphs is added here
for new_node in new_graph.hidden_nodes:
if isinstance(new_node.operation, Cell):
old_cell_name = new_node.operation.cell_name
new_node.operation = copy.deepcopy(new_node.operation)
new_node.operation.cell_name = f'M_{model.model_id}_{old_cell_name}'
assert(phy_model.evaluator is not None)
# When replace logical nodes, merge the training configs when
# input/output nodes are replaced.
evaluator_slot = {} # Model ID -> Slot ID
@ -169,6 +170,9 @@ class LogicalPlan:
# Replace all logical nodes to executable physical nodes
hidden_nodes = phy_graph.hidden_nodes.copy()
node_placements = {}
added_models = []
for node in hidden_nodes:
if isinstance(node, OriginNode):
model_id = node.original_graph.model.model_id
@ -185,12 +189,9 @@ class LogicalPlan:
if isinstance(new_node.operation, _IOPseudoOperation):
model_id = new_node.graph.model.model_id
if model_id not in evaluator_slot:
phy_model.evaluator.kwargs['model_kwargs'].append(new_node.graph.model.evaluator.kwargs.copy())
evaluator_slot[model_id] = len(phy_model.evaluator.kwargs['model_kwargs']) - 1
added_models.append(model_id)
evaluator_slot[model_id] = len(added_models) - 1
slot = evaluator_slot[model_id]
phy_model.evaluator.kwargs['model_kwargs'][slot]['model_id'] = model_id
phy_model.evaluator.kwargs['model_kwargs'][slot]['use_input'] = False
phy_model.evaluator.kwargs['model_kwargs'][slot]['use_output'] = False
else:
slot = evaluator_slot[model_id]
# If a model's inputs/outputs are not used in the multi-model
@ -199,37 +200,47 @@ class LogicalPlan:
# an input/output of a model is used in a multi-model
if new_node.operation.type == '_inputs':
input_slot_mapping[new_node] = slot
phy_model.evaluator.kwargs['model_kwargs'][slot]['use_input'] = True
if new_node.operation.type == '_outputs':
output_slot_mapping[new_node] = slot
phy_model.evaluator.kwargs['model_kwargs'][slot]['use_output'] = True
self.node_replace(node, new_node)
# name prefix of M_ of cells in hidden_nodes of root graphs is added here
# FIXME: merge this rename with non-root graph, only do once.
if isinstance(new_node.operation, Cell):
old_cell_name = new_node.operation.cell_name
new_node.operation = copy.deepcopy(new_node.operation)
new_node.operation.cell_name = f'M_{model_id}_{old_cell_name}'
node_placements[new_node] = placement
# input should be at CPU, move it to GPU first if necessary
if isinstance(new_node.operation, _IOPseudoOperation) and new_node.operation.type == '_inputs':
# hack: only support single_server
node_placements[new_node] = CPUDevice(node_id=placement.node_id)
else:
node_placements[new_node] = placement
node.remove()
# If two nodes are placed on different devices, use ToDevice op to copy the node
existing_edges = phy_graph.edges.copy()
# Avoid copying a node multiple times to the same device
copied_op: Dict[Tuple(Node, PhysicalDevice), Node] = {}
copied_op: Dict[Tuple[Node, Union[GPUDevice, CPUDevice]], Node] = {}
for edge in existing_edges:
head_placement = node_placements[edge.head]
tail_placement = node_placements[edge.tail]
if head_placement != tail_placement:
if head_placement.server != tail_placement.server:
if head_placement.node_id != tail_placement.node_id:
raise ValueError('Cross-server placement is not supported.')
# Same server different devices
if (edge.head, tail_placement) in copied_op:
to_node = copied_op[(edge.head, tail_placement)]
else:
to_operation = Operation.new('ToDevice', {"device": tail_placement.device})
to_node = Node(phy_graph, uid(), edge.head.name + "_to_" + edge.tail.name, to_operation)._register()
dst_name = edge.head.name + "_to_" + edge.tail.name
to_operation = Operation.new(
'ToDevice', {
"device": tail_placement.device_repr(), "src": (
edge.head.name, edge.head_slot), "dst": dst_name})
to_node = Node(phy_graph, uid(), dst_name, to_operation)._register()
Edge((edge.head, edge.head_slot), (to_node, None), _internal=True)._register()
copied_op[(edge.head, tail_placement)] = to_node
edge.head = to_node

nni/retiarii/execution/logical_optimizer/opt_dedup_input.py

@ -4,23 +4,28 @@
from typing import List, Dict, Tuple
from nni.retiarii.utils import uid
from nni.retiarii.evaluator.pytorch.cgo.evaluator import MultiModelSupervisedLearningModule
from nni.common.device import GPUDevice
from ...graph import Graph, Model, Node
from .interface import AbstractOptimizer
from .logical_plan import (AbstractLogicalNode, LogicalGraph, LogicalPlan,
OriginNode, PhysicalDevice)
OriginNode)
_supported_training_modules = ['nni.retiarii.trainer.pytorch.PyTorchImageClassificationTrainer']
_supported_evaluators = [MultiModelSupervisedLearningModule]
class DedupInputNode(AbstractLogicalNode):
def __init__(self, logical_graph: LogicalGraph, node_id: int,
nodes_to_dedup: List[Node], _internal=False):
super().__init__(logical_graph, node_id,
"Dedup_"+nodes_to_dedup[0].name,
"Dedup_" + nodes_to_dedup[0].name,
nodes_to_dedup[0].operation)
self.origin_nodes: List[OriginNode] = nodes_to_dedup.copy()
self.related_models = [_.original_graph.model for _ in self.origin_nodes]
def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) -> Tuple[Node, PhysicalDevice]:
def assemble(self, multi_model_placement: Dict[Model, GPUDevice]) -> Tuple[Node, GPUDevice]:
for node in self.origin_nodes:
if node.original_graph.model in multi_model_placement:
new_node = Node(node.original_graph, node.id,
@ -41,6 +46,12 @@ class DedupInputOptimizer(AbstractOptimizer):
def __init__(self) -> None:
pass
def _check_supported_evaluator(self, evaluator):
for e in _supported_evaluators:
if isinstance(evaluator, e):
return True
return False
def _check_deduplicate_by_node(self, root_node, node_to_check):
if root_node == node_to_check:
return True
@ -48,7 +59,7 @@ class DedupInputOptimizer(AbstractOptimizer):
node_to_check.operation.type == '_inputs' and \
isinstance(root_node, OriginNode) and \
isinstance(node_to_check, OriginNode):
if root_node.original_graph.model.evaluator.module not in _supported_training_modules:
if not self._check_supported_evaluator(root_node.original_graph.model.evaluator):
return False
if root_node.original_graph.model.evaluator == node_to_check.original_graph.model.evaluator:
return True
@ -68,7 +79,7 @@ class DedupInputOptimizer(AbstractOptimizer):
continue
root_node = node
break
if root_node == None:
if root_node is None:
break # end of convert
else:
nodes_to_dedup = []

nni/retiarii/experiment/pytorch.py

@ -50,8 +50,11 @@ class RetiariiExeConfig(ConfigBase):
trial_code_directory: PathLike = '.'
trial_concurrency: int
trial_gpu_number: int = 0
devices: Optional[List[Union[str, GPUDevice]]] = None
max_experiment_duration: Optional[str] = None
max_trial_number: Optional[int] = None
max_concurrency_cgo: Optional[int] = None
batch_waiting_time: Optional[int] = None
nni_manager_ip: Optional[str] = None
debug: bool = False
log_level: Optional[str] = None
@ -134,11 +137,12 @@ def preprocess_model(base_model, trainer, applied_mutators, full_ir=True, dummy_
if mutators is not None and applied_mutators:
raise RuntimeError('Have not supported mixed usage of LayerChoice/InputChoice and mutators, '
'do not use mutators when you use LayerChoice/InputChoice')
'do not use mutators when you use LayerChoice/InputChoice')
if mutators is not None:
applied_mutators = mutators
return base_model_ir, applied_mutators
def debug_mutated_model(base_model, trainer, applied_mutators):
"""
Locally run only one trial without launching an experiment for debug purpose, then exit.
@ -189,7 +193,7 @@ class RetiariiExperiment(Experiment):
self.strategy.run(base_model_ir, self.applied_mutators)
_logger.info('Strategy exit')
# TODO: find out a proper way to show no more trial message on WebUI
#self._dispatcher.mark_experiment_as_ending()
# self._dispatcher.mark_experiment_as_ending()
def start(self, port: int = 8080, debug: bool = False) -> None:
"""
@ -205,14 +209,18 @@ class RetiariiExperiment(Experiment):
"""
atexit.register(self.stop)
devices = self._construct_devices()
# we will probably need an execution engine factory to make this clean and elegant
if self.config.execution_engine == 'base':
from ..execution.base import BaseExecutionEngine
engine = BaseExecutionEngine()
elif self.config.execution_engine == 'cgo':
from ..execution.cgo_engine import CGOExecutionEngine
engine = CGOExecutionEngine(devices = devices)
# assert self.config.trial_gpu_number==1, "trial_gpu_number must be 1 to use CGOExecutionEngine"
assert self.config.batch_waiting_time is not None
devices = self._construct_devices()
engine = CGOExecutionEngine(devices,
max_concurrency=self.config.max_concurrency_cgo,
batch_waiting_time=self.config.batch_waiting_time)
elif self.config.execution_engine == 'py':
from ..execution.python import PurePythonExecutionEngine
engine = PurePythonExecutionEngine()
@ -315,7 +323,7 @@ class RetiariiExperiment(Experiment):
if self._dispatcher_thread is not None:
self._dispatcher.stopping = True
self._dispatcher_thread.join(timeout=1)
if self.id is not None:
nni.runtime.log.stop_experiment_log(self.id)
if self._proc is not None:
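
A configuration sketch for launching an experiment with the CGO engine (field values illustrative; batch_waiting_time must be set when execution_engine is 'cgo', as asserted in start() above):

exp_config = RetiariiExeConfig('local')
exp_config.execution_engine = 'cgo'
exp_config.trial_concurrency = 2
exp_config.trial_gpu_number = 1
exp_config.max_trial_number = 20
exp_config.max_concurrency_cgo = 2  # forwarded to CGOExecutionEngine(max_concurrency=...)
exp_config.batch_waiting_time = 10  # forwarded to CGOExecutionEngine(batch_waiting_time=...)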

nni/retiarii/graph.py

@ -410,7 +410,7 @@ class Graph:
return self is other
def _fork_to(self, model: Model, name_prefix='') -> 'Graph':
new_graph = Graph(model, self.id, name_prefix+self.name, _internal=True)._register()
new_graph = Graph(model, self.id, name_prefix + self.name, _internal=True)._register()
# TODO: use node copy instead
new_graph.input_node.operation.io_names = self.input_node.operation.io_names
new_graph.output_node.operation.io_names = self.output_node.operation.io_names
@ -458,6 +458,11 @@ class Graph:
self.model.graphs[self.name] = self
return self
def _rename_graph(self, old_name, new_name):
self.model.graphs[old_name].name = new_name
self.model.graphs[new_name] = self.model.graphs[old_name]
del self.model.graphs[old_name]
@staticmethod
def _load(model: Model, name: str, ir: Any) -> 'Graph':
graph = Graph(model, uid(), name, _internal=True)

nni/retiarii/integration.py

@ -158,4 +158,4 @@ class RetiariiAdvisor(MsgDispatcherBase):
return value['default']
else:
return value
return value
return value

nni/retiarii/operation.py

@ -98,6 +98,10 @@ class PyTorchOperation(Operation):
if hasattr(subclass, '_ori_type_name') and \
subclass_name in subclass._ori_type_name:
return subclass
for subclass in cls.__subclasses__():
if hasattr(subclass, '_artificial_op_name') and \
subclass_name in subclass._artificial_op_name:
return subclass
return cls
@classmethod

nni/retiarii/operation_def/torch_op_def.py

@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import (Any, List)
from typing import (Any, Dict, List)
import torch
@ -32,21 +32,27 @@ scalar_type_to_pytorch_type = [
'torch.bool', # 11
]
class NoOpIdentity(PyTorchOperation):
"""
this operator type is added by us
"""
_ori_type_name = ['noop_identity']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
return f'{output} = {", ".join(inputs)}'
class ModuleOperator(PyTorchOperation):
_ori_type_name = ['ModuleOperator', 'shared']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
return f'{output} = self.{field}({", ".join(inputs)})'
class FunctionalOperator(PyTorchOperation):
_ori_type_name = ['FunctionalOperator']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
func_name = self.type[len('Function.'):]
if not hasattr(torch.nn.functional, func_name):
@ -54,8 +60,10 @@ class FunctionalOperator(PyTorchOperation):
f'{func_name} is not in it.')
return f'{output} = F.{func_name}({", ".join(inputs)})'
class PrimConstant(PyTorchOperation):
_ori_type_name = ['prim::Constant']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
# TODO: refactor this part, maybe we can remove the code gen of prim::Constant
# TODO: deal with all the types
@ -75,63 +83,83 @@ class PrimConstant(PyTorchOperation):
else:
raise RuntimeError(f'unsupported type of prim::Constant: {self.parameters["type"]}')
class PrimListConstruct(PyTorchOperation):
_ori_type_name = ['prim::ListConstruct']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
return f'{output} = [{", ".join(inputs)}]'
class PrimListUnpack(PyTorchOperation):
_ori_type_name = ['prim::ListUnpack']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
return f'{output} = {inputs[0]}'
class PrimTupleConstruct(PyTorchOperation):
_ori_type_name = ['prim::TupleConstruct']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
return f'{output} = ({", ".join(inputs)})'
class PrimTupleUnpack(PyTorchOperation):
_ori_type_name = ['prim::TupleUnpack']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
# have single output here, because the following code uses index to access the unpacked values
assert len(inputs) == 1
return f'{output} = {inputs[0]}'
class PrimGetAttr(PyTorchOperation):
_ori_type_name = ['prim::GetAttr']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
if self.parameters['value'] is not None:
return f"{output} = {self.parameters['value']}"
else:
return f"{output} = {self.parameters['input']}.{self.parameters['name']}"
class SimpleMember(PyTorchOperation):
_ori_type_name = ['prim::is_cuda', 'prim::data']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
member_name = self.type.split('::')[-1]
return f'{output} = {inputs[0]}.{member_name}'
class AtenContiguous(PyTorchOperation):
_ori_type_name = ['aten::contiguous']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
# defined in pytorch/c10/core/MemoryFormat.h
assert inputs_value[1] in [0, 1, 2]
return f'{output} = {inputs[0]}.contiguous(memory_format={mem_format[inputs_value[1]]})'
class AtenGetitem(PyTorchOperation):
_ori_type_name = ['aten::__getitem__']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
assert len(inputs) == 2
return f'{output} = {inputs[0]}[{inputs[1]}]'
class AtenAppend(PyTorchOperation):
_ori_type_name = ['aten::append']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
assert len(inputs) == 2
return f'_, {output} = {inputs[0]}.append({inputs[1]}), {inputs[0]}'
class MergedSlice(PyTorchOperation):
_ori_type_name = ['MergedSlice']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
if (len(inputs) - 1) % 4 == 0:
slices = []
@ -148,23 +176,30 @@ class MergedSlice(PyTorchOperation):
# the following Aten classes means these aten ops are not in torch.Tensor
class AtenBool(PyTorchOperation):
_ori_type_name = ['aten::Bool']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
return f'{output} = bool({inputs[0]})'
class AtenNot(PyTorchOperation):
_ori_type_name = ['aten::__not__']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
return f'{output} = not {inputs[0]}'
class AtenCat(PyTorchOperation):
_ori_type_name = ['aten::cat']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
assert len(inputs) == 2
return f'{output} = torch.cat({inputs[0]}, dim={inputs[1]})'
#====================================
# ====================================
class AtenTensors(PyTorchOperation):
_ori_type_name = ['aten::full', 'aten::full_like', 'aten::empty_like',
@ -209,20 +244,26 @@ class AtenTensors(PyTorchOperation):
else:
return f'{output} = {inputs[0]}.{op_name}({", ".join(args_list[1:])})'
#====================================
# ====================================
class AtenFloordiv(PyTorchOperation):
_ori_type_name = ['aten::floordiv']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
return f'{output} = {inputs[0]} // {inputs[1]}'
class AtenLen(PyTorchOperation):
_ori_type_name = ['aten::len']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
return f'{output} = len({inputs[0]})'
class AtenIntImplicit(PyTorchOperation):
_ori_type_name = ['aten::IntImplicit', 'aten::Float', 'aten::Int', 'aten::ScalarImplicit']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
if self.type.endswith('Implicit'):
return f'{output} = {inputs[0]}'
@ -231,11 +272,14 @@ class AtenIntImplicit(PyTorchOperation):
elif self.type == 'aten::Float':
return f'{output} = float({inputs[0]})'
class AtenIndex(PyTorchOperation):
_ori_type_name = ['aten::index']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
return f'{output} = {inputs[0]}[{inputs[1]}]'
ManuallyChooseDef = {
'aten::flatten': [('start_dim', 'int', '0'), ('end_dim', 'int', '-1')],
'aten::split': [('split_size', 'int', 'None'), ('dim', 'int', '0')],
@ -248,21 +292,24 @@ ManuallyChooseDef = {
}
TensorOpExceptions = {
'aten::sub': lambda output, inputs: f'{output} = {inputs[0]} - {inputs[1]}', # example: x.size(1) - 3
'aten::add': lambda output, inputs: f'{output} = {inputs[0]} + {inputs[1]}' # example: input.shape[0] + 5
'aten::sub': lambda output, inputs: f'{output} = {inputs[0]} - {inputs[1]}', # example: x.size(1) - 3
'aten::add': lambda output, inputs: f'{output} = {inputs[0]} + {inputs[1]}' # example: input.shape[0] + 5
}
TorchOpExclude = ['aten::Size', 'aten::as_tensor', 'aten::device',
'aten::manual_seed', 'aten::quantized_gru', 'aten::quantized_lstm',
'aten::save', 'aten::tensor', 'aten::wait'
]
]
def _hidden(name):
return name.startswith('_') and not name.startswith('__')
def _emit_args(args):
# filter out the `out` argument here
return [(arg.name, str(arg.type), str(arg.default_value)) for arg in args] # if arg.name != 'out'
return [(arg.name, str(arg.type), str(arg.default_value)) for arg in args] # if arg.name != 'out'
def _get_tensor_ops():
def is_tensor_method(schema):
@ -291,6 +338,7 @@ def _get_tensor_ops():
return op_args.keys(), op_args
def _get_torch_ops():
torch_op_args = {}
for mod in torch.jit._builtins._modules_containing_builtins:
@ -316,6 +364,7 @@ def _get_torch_ops():
return torch_op_args.keys(), torch_op_args
def _get_torch_ops_exclude_tensor_ops():
tensor_op_names, _ = _get_tensor_ops()
torch_op_names, torch_ops = _get_torch_ops()
@ -330,6 +379,7 @@ def _get_torch_ops_exclude_tensor_ops():
return torch_exclude_ops.keys(), torch_exclude_ops
class TensorOps(PyTorchOperation):
"""
corresponding to _get_tensor_ops in torch.jit.supported_ops
@ -346,7 +396,7 @@ class TensorOps(PyTorchOperation):
name = ','.join([arg[0] for arg in each])
concated_names.append(name)
for i in range(len(concated_names) - 1):
if concated_names[i] != concated_names[i+1]:
if concated_names[i] != concated_names[i + 1]:
return False
return True
@ -383,6 +433,7 @@ class TensorOps(PyTorchOperation):
args_str = ', '.join([f'{name}={inputs[i+1]}' for i, (name, t, default) in enumerate(matched_args)])
return f'{output} = {inputs[0]}.{op_name}({args_str})'
class TorchOps(PyTorchOperation):
"""
corresponding to _get_nn_functional_ops in torch.jit.supported_ops
@ -400,7 +451,7 @@ class TorchOps(PyTorchOperation):
name = ','.join([arg[0] for arg in each])
concated_names.append(name)
for i in range(len(concated_names) - 1):
if concated_names[i] != concated_names[i+1]:
if concated_names[i] != concated_names[i + 1]:
return False
return True
@ -424,19 +475,36 @@ class TorchOps(PyTorchOperation):
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
matched_args = TorchOps._get_matched_args(self.type, inputs)
op_name = self.type.split('::')[-1]
args_str = ', '.join([f'{name}={inputs[i]}' if t.startswith('Optional[') else f'{inputs[i]}' \
for i, (name, t, default) in enumerate(matched_args)])
args_str = ', '.join([f'{name}={inputs[i]}' if t.startswith('Optional[') else f'{inputs[i]}'
for i, (name, t, default) in enumerate(matched_args)])
return f'{output} = torch.{op_name}({args_str})'
class AtenAvgpool2d(PyTorchOperation):
# NOTE: it is not included in the above aten ops for unknown reason
_ori_type_name = ['aten::avg_pool2d']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
return f'{output} = F.avg_pool2d({", ".join(inputs)})'
class ToDevice(PyTorchOperation):
_artificial_op_name = "ToDevice"
def __init__(self, type_name: str, parameters: Dict[str, Any], _internal: bool = False):
self.type = "ToDevice"
self.device = parameters['device']
self.src = parameters['src']
self.dst = parameters['dst']
def __repr__(self):
return f'to("{self.device}")'
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = {inputs[0]}.to("{self.device}")'
class AtenDet(PyTorchOperation):
# for torch 1.9
# NOTE: it is not included in the above aten ops, maybe because torch.det is an alias for torch.linalg.det
_ori_type_name = ['aten::linalg_det']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
return f'{output} = torch.det({inputs[0]})'

View file

@ -0,0 +1,165 @@
from collections import OrderedDict
from typing import (List, Optional)
import torch
import torch.nn as torch_nn
#sys.path.append(str(Path(__file__).resolve().parents[2]))
import ops
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import basic_unit
@basic_unit
class AuxiliaryHead(nn.Module):
""" Auxiliary head in 2/3 place of network to let the gradient flow well """
def __init__(self, input_size, C, n_classes):
""" assuming input size 7x7 or 8x8 """
assert input_size in [7, 8]
super().__init__()
self.net = nn.Sequential(
nn.ReLU(inplace=True),
nn.AvgPool2d(5, stride=input_size - 5, padding=0, count_include_pad=False), # 2x2 out
nn.Conv2d(C, 128, kernel_size=1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 768, kernel_size=2, bias=False), # 1x1 out
nn.BatchNorm2d(768),
nn.ReLU(inplace=True)
)
self.linear = nn.Linear(768, n_classes)
def forward(self, x):
out = self.net(x)
out = out.view(out.size(0), -1) # flatten
logits = self.linear(out)
return logits
class Node(nn.Module):
def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect):
super().__init__()
self.ops = nn.ModuleList()
choice_keys = []
for i in range(num_prev_nodes):
stride = 2 if i < num_downsample_connect else 1
choice_keys.append("{}_p{}".format(node_id, i))
self.ops.append(
nn.LayerChoice([
ops.PoolBN('max', channels, 3, stride, 1, affine=False),
ops.PoolBN('avg', channels, 3, stride, 1, affine=False),
nn.Identity() if stride == 1 else ops.FactorizedReduce(channels, channels, affine=False),
ops.SepConv(channels, channels, 3, stride, 1, affine=False),
ops.SepConv(channels, channels, 5, stride, 2, affine=False),
ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False),
ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False)
]))
self.drop_path = ops.DropPath()
self.input_switch = nn.InputChoice(n_candidates=num_prev_nodes, n_chosen=2)
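# input_switch selects n_chosen=2 of the candidate tensors produced below
# (InputChoice reduces the chosen tensors, by default via summation)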
def forward(self, prev_nodes: List['Tensor']) -> 'Tensor':
#assert self.ops.__len__() == len(prev_nodes)
#out = [op(node) for op, node in zip(self.ops, prev_nodes)]
out = []
for i, op in enumerate(self.ops):
out.append(op(prev_nodes[i]))
#out = [self.drop_path(o) if o is not None else None for o in out]
return self.input_switch(out)
class Cell(nn.Module):
def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction):
super().__init__()
self.reduction = reduction
self.n_nodes = n_nodes
# If previous cell is reduction cell, current input size does not match with
# output size of cell[k-2]. So the output[k-2] should be reduced by preprocessing.
if reduction_p:
self.preproc0 = ops.FactorizedReduce(channels_pp, channels, affine=False)
else:
self.preproc0 = ops.StdConv(channels_pp, channels, 1, 1, 0, affine=False)
self.preproc1 = ops.StdConv(channels_p, channels, 1, 1, 0, affine=False)
# generate dag
self.mutable_ops = nn.ModuleList()
for depth in range(2, self.n_nodes + 2):
self.mutable_ops.append(Node("{}_n{}".format("reduce" if reduction else "normal", depth),
depth, channels, 2 if reduction else 0))
def forward(self, s0, s1):
# s0, s1 are the outputs of previous previous cell and previous cell, respectively.
tensors = [self.preproc0(s0), self.preproc1(s1)]
new_tensors = []
for node in self.mutable_ops:
tmp = tensors + new_tensors
cur_tensor = node(tmp)
new_tensors.append(cur_tensor)
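# concatenate the outputs of all intermediate nodes along the channel dimension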
output = torch.cat(new_tensors, dim=1)
return output
class CNN(nn.Module):
def __init__(self, input_size, in_channels, channels, n_classes, n_layers, n_nodes=4,
stem_multiplier=3, auxiliary=False):
super().__init__()
self.in_channels = in_channels
self.channels = channels
self.n_classes = n_classes
self.n_layers = n_layers
self.aux_pos = 2 * n_layers // 3 if auxiliary else -1
c_cur = stem_multiplier * self.channels
self.stem = nn.Sequential(
nn.Conv2d(in_channels, c_cur, 3, 1, 1, bias=False),
nn.BatchNorm2d(c_cur)
)
# for the first cell, stem is used for both s0 and s1
# [!] channels_pp and channels_p are output channel sizes, while c_cur is the input channel size.
channels_pp, channels_p, c_cur = c_cur, c_cur, channels
self.cells = nn.ModuleList()
reduction_p, reduction = False, False
for i in range(n_layers):
reduction_p, reduction = reduction, False
# Reduce feature map size and double channels at the 1/3 and 2/3 layers.
if i in [n_layers // 3, 2 * n_layers // 3]:
c_cur *= 2
reduction = True
cell = Cell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction)
self.cells.append(cell)
c_cur_out = c_cur * n_nodes
channels_pp, channels_p = channels_p, c_cur_out
#if i == self.aux_pos:
# self.aux_head = AuxiliaryHead(input_size // 4, channels_p, n_classes)
self.gap = nn.AdaptiveAvgPool2d(1)
self.linear = nn.Linear(channels_p, n_classes)
def forward(self, x):
s0 = s1 = self.stem(x)
#aux_logits = None
for i, cell in enumerate(self.cells):
s0, s1 = s1, cell(s0, s1)
#if i == self.aux_pos and self.training:
# aux_logits = self.aux_head(s1)
out = self.gap(s1)
out = out.view(out.size(0), -1) # flatten
logits = self.linear(out)
#if aux_logits is not None:
# return logits, aux_logits
return logits
def drop_path_prob(self, p):
for module in self.modules():
if isinstance(module, ops.DropPath):
module.p = p
if __name__ == '__main__':
base_model = CNN(32, 3, 16, 10, 8)

View file

@ -0,0 +1,133 @@
import torch
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import basic_unit
@basic_unit
class DropPath(nn.Module):
def __init__(self, p=0.):
"""
Drop path with probability.
Parameters
----------
p : float
Probability of a path being zeroed.
"""
super().__init__()
self.p = p
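# Zeroing whole samples with probability p and dividing the survivors by
# keep_prob leaves the expected activation unchanged (inverted-dropout scaling).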
def forward(self, x):
if self.training and self.p > 0.:
keep_prob = 1. - self.p
# per data point mask
mask = torch.zeros((x.size(0), 1, 1, 1), device=x.device).bernoulli_(keep_prob)
return x / keep_prob * mask
return x
@basic_unit
class PoolBN(nn.Module):
"""
AvgPool or MaxPool with BN. `pool_type` must be `max` or `avg`.
"""
def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True):
super().__init__()
if pool_type.lower() == 'max':
self.pool = nn.MaxPool2d(kernel_size, stride, padding)
elif pool_type.lower() == 'avg':
self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False)
else:
raise ValueError(f'unsupported pool_type: {pool_type}')
self.bn = nn.BatchNorm2d(C, affine=affine)
def forward(self, x):
out = self.pool(x)
out = self.bn(out)
return out
@basic_unit
class StdConv(nn.Module):
"""
Standard conv: ReLU - Conv - BN
"""
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super().__init__()
self.net = nn.Sequential(
nn.ReLU(),
nn.Conv2d(C_in, C_out, kernel_size, stride, padding, bias=False),
nn.BatchNorm2d(C_out, affine=affine)
)
def forward(self, x):
return self.net(x)
@basic_unit
class FacConv(nn.Module):
"""
Factorized conv: ReLU - Conv(Kx1) - Conv(1xK) - BN
"""
def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True):
super().__init__()
self.net = nn.Sequential(
nn.ReLU(),
nn.Conv2d(C_in, C_in, (kernel_length, 1), stride, padding, bias=False),
nn.Conv2d(C_in, C_out, (1, kernel_length), stride, padding, bias=False),
nn.BatchNorm2d(C_out, affine=affine)
)
def forward(self, x):
return self.net(x)
@basic_unit
class DilConv(nn.Module):
"""
(Dilated) depthwise separable conv.
ReLU - (Dilated) depthwise separable - Pointwise - BN.
If dilation == 2, 3x3 conv => 5x5 receptive field, 5x5 conv => 9x9 receptive field.
"""
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
super().__init__()
self.net = nn.Sequential(
nn.ReLU(),
nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in,
bias=False),
nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=affine)
)
def forward(self, x):
return self.net(x)
@basic_unit
class SepConv(nn.Module):
"""
Depthwise separable conv.
DilConv(dilation=1) * 2.
"""
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super().__init__()
self.net = nn.Sequential(
DilConv(C_in, C_in, kernel_size, stride, padding, dilation=1, affine=affine),
DilConv(C_in, C_out, kernel_size, 1, padding, dilation=1, affine=affine)
)
def forward(self, x):
return self.net(x)
@basic_unit
class FactorizedReduce(nn.Module):
"""
Reduce feature map size by factorized pointwise (stride=2).
"""
def __init__(self, C_in, C_out, affine=True):
super().__init__()
self.relu = nn.ReLU()
self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.bn = nn.BatchNorm2d(C_out, affine=affine)
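# conv2 consumes the input shifted by one pixel (x[:, :, 1:, 1:]), so the two
# stride-2 convolutions together cover both even and odd spatial positions.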
def forward(self, x):
x = self.relu(x)
out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
out = self.bn(out)
return out

View file

@ -0,0 +1,54 @@
import json
import os
import sys
import torch
from pathlib import Path
import nni.retiarii.evaluator.pytorch.cgo.evaluator as cgo
import nni.retiarii.evaluator.pytorch.lightning as pl
import nni.retiarii.strategy as strategy
from nni.retiarii import serialize
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
from torchvision import transforms
from torchvision.datasets import CIFAR10
from darts_model import CNN
if __name__ == '__main__':
base_model = CNN(32, 3, 16, 10, 8)
train_transform = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
valid_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
train_dataset = serialize(CIFAR10, root='data/cifar10', train=True, download=True, transform=train_transform)
test_dataset = serialize(CIFAR10, root='data/cifar10', train=False, download=True, transform=valid_transform)
trainer = cgo.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
max_epochs=1, limit_train_batches=0.2)
simple_strategy = strategy.Random()
exp = RetiariiExperiment(base_model, trainer, [], simple_strategy)
exp_config = RetiariiExeConfig('local')
exp_config.experiment_name = 'darts_search'
exp_config.execution_engine = 'cgo'
exp_config.trial_concurrency = 3
# since CGO may merge multiple trials into one, RetiariiExperiment may run more models than max_trial_number:
# e.g., with max_trial_number = 3, it would actually run 9 models if each merged trial contains 3 models from the strategy
exp_config.max_trial_number = 100
exp_config.devices = ['cuda:0', 'cuda:1', 'cuda:2']
exp_config.trial_gpu_number = 1
exp_config.batch_waiting_time = 100
exp_config.training_service.use_active_gpu = True
exp_config.training_service.gpu_indices = [0, 1, 2]
exp.run(exp_config, 8081)

View file

@ -0,0 +1,298 @@
from nni.retiarii import basic_unit
import nni.retiarii.nn.pytorch as nn
import warnings
import torch
import torch.nn as torch_nn
from torchvision.models.utils import load_state_dict_from_url
import torch.nn.functional as F
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parents[2]))
# Paper suggests 0.9997 momentum for TensorFlow; the equivalent PyTorch momentum is 1 - 0.9997.
_BN_MOMENTUM = 1 - 0.9997
_FIRST_DEPTH = 32
_MOBILENET_V2_FILTERS = [16, 24, 32, 64, 96, 160, 320]
_MOBILENET_V2_NUM_LAYERS = [1, 2, 3, 4, 3, 3, 1]
class _ResidualBlock(nn.Module):
def __init__(self, net):
super().__init__()
self.net = net
def forward(self, x):
return self.net(x) + x
class _InvertedResidual(nn.Module):
def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor, skip, bn_momentum=0.1):
super(_InvertedResidual, self).__init__()
assert stride in [1, 2]
assert kernel_size in [3, 5]
mid_ch = in_ch * expansion_factor
self.apply_residual = skip and in_ch == out_ch and stride == 1
self.layers = nn.Sequential(
# Pointwise
nn.Conv2d(in_ch, mid_ch, 1, bias=False),
nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
nn.ReLU(inplace=True),
# Depthwise
nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=kernel_size // 2,
stride=stride, groups=mid_ch, bias=False),
nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
nn.ReLU(inplace=True),
# Linear pointwise. Note that there's no activation.
nn.Conv2d(mid_ch, out_ch, 1, bias=False),
nn.BatchNorm2d(out_ch, momentum=bn_momentum))
def forward(self, input):
if self.apply_residual:
ret = self.layers(input) + input
else:
ret = self.layers(input)
return ret
def _stack_inverted_residual(in_ch, out_ch, kernel_size, skip, stride, exp_factor, repeats, bn_momentum):
""" Creates a stack of inverted residuals. """
assert repeats >= 1
# First one has no skip, because feature map size changes.
first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor, skip, bn_momentum=bn_momentum)
remaining = []
for _ in range(1, repeats):
remaining.append(_InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor, skip, bn_momentum=bn_momentum))
return nn.Sequential(first, *remaining)
def _stack_normal_conv(in_ch, out_ch, kernel_size, skip, dconv, stride, repeats, bn_momentum):
assert repeats >= 1
stack = []
for i in range(repeats):
s = stride if i == 0 else 1
if dconv:
modules = [
nn.Conv2d(in_ch, in_ch, kernel_size, padding=kernel_size // 2, stride=s, groups=in_ch, bias=False),
nn.BatchNorm2d(in_ch, momentum=bn_momentum),
nn.ReLU(inplace=True),
nn.Conv2d(in_ch, out_ch, 1, padding=0, stride=1, bias=False),
nn.BatchNorm2d(out_ch, momentum=bn_momentum)
]
else:
modules = [
nn.Conv2d(in_ch, out_ch, kernel_size, padding=kernel_size // 2, stride=s, bias=False),
nn.ReLU(inplace=True),
nn.BatchNorm2d(out_ch, momentum=bn_momentum)
]
if skip and in_ch == out_ch and s == 1:
# use different implementations for skip and no-skip to align with PyTorch
stack.append(_ResidualBlock(nn.Sequential(*modules)))
else:
stack += modules
in_ch = out_ch
return stack
def _round_to_multiple_of(val, divisor, round_up_bias=0.9):
""" Asymmetric rounding to make `val` divisible by `divisor`. With default
bias, will round up, unless the number is no more than 10% greater than the
smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88. """
assert 0.0 < round_up_bias < 1.0
new_val = max(divisor, int(val + divisor / 2) // divisor * divisor)
return new_val if new_val >= round_up_bias * val else new_val + divisor
def _get_depths(depths, alpha):
""" Scales tensor depths as in reference MobileNet code, prefers rouding up
rather than down. """
return [_round_to_multiple_of(depth * alpha, 8) for depth in depths]
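# A minimal sanity check, with values worked out by hand from the rules above:
# _round_to_multiple_of(83, 8) == 80, _round_to_multiple_of(84, 8) == 88,
# and _get_depths([16, 24], 0.5) == [8, 16].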
class MNASNet(nn.Module):
""" MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This
implements the B1 variant of the model.
>>> model = MNASNet(1000, 1.0)
>>> x = torch.rand(1, 3, 224, 224)
>>> y = model(x)
>>> y.dim()
2
>>> y.nelement()
1000
"""
# Version 2 adds depth scaling in the initial stages of the network.
_version = 2
def __init__(self, alpha, depths, convops, kernel_sizes, num_layers,
skips, num_classes=1000, dropout=0.2):
super().__init__()
assert alpha > 0.0
assert len(depths) == len(convops) == len(kernel_sizes) == len(num_layers) == len(skips) == 7
self.alpha = alpha
self.num_classes = num_classes
depths = _get_depths([_FIRST_DEPTH] + depths, alpha)
base_filter_sizes = [16, 24, 40, 80, 96, 192, 320]
exp_ratios = [3, 3, 3, 6, 6, 6, 6]
strides = [1, 2, 2, 2, 1, 2, 1]
layers = [
# First layer: regular conv.
nn.Conv2d(3, depths[0], 3, padding=1, stride=2, bias=False),
nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),
nn.ReLU(inplace=True),
]
count = 0
# for conv, prev_depth, depth, ks, skip, stride, repeat, exp_ratio in \
# zip(convops, depths[:-1], depths[1:], kernel_sizes, skips, strides, num_layers, exp_ratios):
for filter_size, exp_ratio, stride in zip(base_filter_sizes, exp_ratios, strides):
# TODO: restrict that "choose" can only be used within mutator
ph = nn.Placeholder(label=f'mutable_{count}', **{
'kernel_size_options': [1, 3, 5],
'n_layer_options': [1, 2, 3, 4],
'op_type_options': ['__mutated__.base_mnasnet.RegularConv',
'__mutated__.base_mnasnet.DepthwiseConv',
'__mutated__.base_mnasnet.MobileConv'],
# 'se_ratio_options': [0, 0.25],
'skip_options': ['identity', 'no'],
'n_filter_options': [int(filter_size*x) for x in [0.75, 1.0, 1.25]],
'exp_ratio': exp_ratio,
'stride': stride,
'in_ch': depths[0] if count == 0 else None
})
layers.append(ph)
'''if conv == "mconv":
# MNASNet blocks: stacks of inverted residuals.
layers.append(_stack_inverted_residual(prev_depth, depth, ks, skip,
stride, exp_ratio, repeat, _BN_MOMENTUM))
else:
# Normal conv and depth-separated conv
layers += _stack_normal_conv(prev_depth, depth, ks, skip, conv == "dconv",
stride, repeat, _BN_MOMENTUM)'''
count += 1
if count >= 2:
break
layers += [
# Final mapping to classifier input.
nn.Conv2d(depths[7], 1280, 1, padding=0, stride=1, bias=False),
nn.BatchNorm2d(1280, momentum=_BN_MOMENTUM),
nn.ReLU(inplace=True),
]
self.layers = nn.Sequential(*layers)
self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True),
nn.Linear(1280, num_classes))
self._initialize_weights()
#self.for_test = 10
def forward(self, x):
# if self.for_test == 10:
x = self.layers(x)
# Equivalent to global avgpool and removing H and W dimensions.
x = x.mean([2, 3])
x = F.relu(x)
return self.classifier(x)
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
torch_nn.init.kaiming_normal_(m.weight, mode="fan_out",
nonlinearity="relu")
if m.bias is not None:
torch_nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d):
torch_nn.init.ones_(m.weight)
torch_nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
torch_nn.init.kaiming_uniform_(m.weight, mode="fan_out",
nonlinearity="sigmoid")
torch_nn.init.zeros_(m.bias)
def test_model(model):
model(torch.randn(2, 3, 224, 224))
# ==================== definition of candidate op classes ====================
BN_MOMENTUM = 1 - 0.9997
class RegularConv(nn.Module):
def __init__(self, kernel_size, in_ch, out_ch, skip, exp_ratio, stride):
super().__init__()
self.kernel_size = kernel_size
self.in_ch = in_ch
self.out_ch = out_ch
self.skip = skip
self.exp_ratio = exp_ratio
self.stride = stride
self.conv = nn.Conv2d(in_ch, out_ch, kernel_size, padding=kernel_size // 2, stride=stride, bias=False)
self.relu = nn.ReLU(inplace=True)
self.bn = nn.BatchNorm2d(out_ch, momentum=BN_MOMENTUM)
def forward(self, x):
out = self.bn(self.relu(self.conv(x)))
if self.skip == 'identity':
out = out + x
return out
class DepthwiseConv(nn.Module):
def __init__(self, kernel_size, in_ch, out_ch, skip, exp_ratio, stride):
super().__init__()
self.kernel_size = kernel_size
self.in_ch = in_ch
self.out_ch = out_ch
self.skip = skip
self.exp_ratio = exp_ratio
self.stride = stride
self.conv1 = nn.Conv2d(in_ch, in_ch, kernel_size, padding=kernel_size // 2, stride=stride, groups=in_ch, bias=False)
self.bn1 = nn.BatchNorm2d(in_ch, momentum=BN_MOMENTUM)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(in_ch, out_ch, 1, padding=0, stride=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_ch, momentum=BN_MOMENTUM)
def forward(self, x):
out = self.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
if self.skip == 'identity':
out = out + x
return out
class MobileConv(nn.Module):
def __init__(self, kernel_size, in_ch, out_ch, skip, exp_ratio, stride):
super().__init__()
self.kernel_size = kernel_size
self.in_ch = in_ch
self.out_ch = out_ch
self.skip = skip
self.exp_ratio = exp_ratio
self.stride = stride
mid_ch = in_ch * exp_ratio
self.layers = nn.Sequential(
# Pointwise
nn.Conv2d(in_ch, mid_ch, 1, bias=False),
nn.BatchNorm2d(mid_ch, momentum=BN_MOMENTUM),
nn.ReLU(inplace=True),
# Depthwise
nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=(kernel_size - 1) // 2,
stride=stride, groups=mid_ch, bias=False),
nn.BatchNorm2d(mid_ch, momentum=BN_MOMENTUM),
nn.ReLU(inplace=True),
# Linear pointwise. Note that there's no activation.
nn.Conv2d(mid_ch, out_ch, 1, bias=False),
nn.BatchNorm2d(out_ch, momentum=BN_MOMENTUM))
def forward(self, x):
out = self.layers(x)
if self.skip == 'identity':
out = out + x
return out
# mnasnet0_5
ir_module = _InvertedResidual(16, 16, 3, 1, 1, True)

View file

@ -0,0 +1,64 @@
import logging
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parents[2]))
from nni.retiarii import Mutator
from base_mnasnet import RegularConv, DepthwiseConv, MobileConv
_logger = logging.getLogger(__name__)
class BlockMutator(Mutator):
def __init__(self, target: str):
super(BlockMutator, self).__init__()
self.target = target
def mutate(self, model):
nodes = model.get_nodes_by_label(self.target)
assert len(nodes) == 1
node = nodes[0]
graph = node.graph
related_info = node.operation.parameters
kernel_size = self.choice(related_info['kernel_size_options'])
op_type = self.choice(related_info['op_type_options'])
#self.choice(related_info['se_ratio_options'])
skip = self.choice(related_info['skip_options'])
n_filter = self.choice(related_info['n_filter_options'])
if related_info['in_ch'] is not None:
in_ch = related_info['in_ch']
else:
assert len(node.predecessors) == 1
the_node = node.predecessors[0]
_logger.debug(repr(the_node.operation.parameters))
_logger.debug(the_node.__repr__())
in_ch = the_node.operation.parameters['out_ch']
# update the placeholder to be a new operation
node.update_operation(op_type, {
'kernel_size': kernel_size,
'in_ch': in_ch,
'out_ch': n_filter,
'skip': 'no',
'exp_ratio': related_info['exp_ratio'],
'stride': related_info['stride']
})
# insert new nodes after the placeholder
n_layer = self.choice(related_info['n_layer_options'])
for i in range(1, n_layer):
node = graph.insert_node_on_edge(node.outgoing_edges[0],
'{}_{}'.format(self.target, i),
op_type,
{'kernel_size': kernel_size,
'in_ch': n_filter,
'out_ch': n_filter,
'skip': skip,
'exp_ratio': related_info['exp_ratio'],
'stride': 1})
# fix possible shape mismatch
# TODO: use a formal API to update the parameters
if len(node.successors) == 1 and 'in_channels' in node.successors[0].operation.parameters:
node.successors[0].operation.parameters['in_channels'] = n_filter

View file

@ -0,0 +1,80 @@
import os
import sys
import torch
from pathlib import Path
import nni.retiarii.evaluator.pytorch.lightning as pl
import nni.retiarii.evaluator.pytorch.cgo.evaluator as cgo
from nni.retiarii import serialize
from base_mnasnet import MNASNet
from nni.experiment import RemoteMachineConfig
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
from nni.retiarii.strategy import TPEStrategy
from torchvision import transforms
from torchvision.datasets import CIFAR10
from mutator import BlockMutator
if __name__ == '__main__':
_DEFAULT_DEPTHS = [16, 24, 40, 80, 96, 192, 320]
_DEFAULT_CONVOPS = ["dconv", "mconv", "mconv", "mconv", "mconv", "mconv", "mconv"]
_DEFAULT_SKIPS = [False, True, True, True, True, True, True]
_DEFAULT_KERNEL_SIZES = [3, 3, 5, 5, 3, 5, 3]
_DEFAULT_NUM_LAYERS = [1, 3, 3, 3, 2, 4, 1]
base_model = MNASNet(0.5, _DEFAULT_DEPTHS, _DEFAULT_CONVOPS, _DEFAULT_KERNEL_SIZES,
_DEFAULT_NUM_LAYERS, _DEFAULT_SKIPS)
train_transform = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
valid_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
train_dataset = serialize(CIFAR10, root='data/cifar10', train=True, download=True, transform=train_transform)
test_dataset = serialize(CIFAR10, root='data/cifar10', train=False, download=True, transform=valid_transform)
# trainer = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
# val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
# max_epochs=1, limit_train_batches=0.2)
trainer = cgo.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
max_epochs=1, limit_train_batches=0.2)
applied_mutators = [
BlockMutator('mutable_0'),
BlockMutator('mutable_1')
]
simple_strategy = TPEStrategy()
exp = RetiariiExperiment(base_model, trainer, applied_mutators, simple_strategy)
exp_config = RetiariiExeConfig('remote')
exp_config.experiment_name = 'darts_search'
exp_config.trial_concurrency = 3
exp_config.max_trial_number = 10
exp_config.trial_gpu_number = 1
exp_config.training_service.use_active_gpu = True
exp_config.training_service.reuse_mode = True
exp_config.training_service.gpu_indices = [0, 1, 2]
exp_config.max_concurrency_cgo = 1
exp_config.batch_waiting_time = 0
rm_conf = RemoteMachineConfig()
rm_conf.host = '127.0.0.1'
rm_conf.user = 'xxx'
rm_conf.password = 'xxx'
rm_conf.port = 22
rm_conf.python_path = '/home/xxx/py38/bin'
rm_conf.gpu_indices = [0, 1, 2]
rm_conf.use_active_gpu = True
rm_conf.max_trial_number_per_gpu = 3
exp_config.training_service.machine_list = [rm_conf]
exp_config.execution_engine = 'cgo'
exp.run(exp_config, 8099)

View file

@ -31,7 +31,8 @@ if __name__ == '__main__':
test_dataset = serialize(CIFAR10, root='data/cifar10', train=False, download=True, transform=valid_transform)
trainer = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
max_epochs=1, limit_train_batches=0.2)
max_epochs=1, limit_train_batches=0.2,
progress_bar_refresh_rate=0)
simple_strategy = strategy.Random()

View file

@ -1,363 +0,0 @@
{
"_model__stem":{
"inputs":[
"_inputs__1"
],
"outputs":[
"pool2__1"
],
"nodes":{
"_model__stem__conv1":{
"operation":{
"type":"__torch__.torch.nn.modules.conv.Conv2d",
"parameters":{
"out_channels":32,
"in_channels":1,
"kernel_size":5
}
}
},
"_model__stem__pool1":{
"operation":{
"type":"__torch__.torch.nn.modules.pooling.MaxPool2d",
"parameters":{
"kernel_size":2
}
}
},
"_model__stem__conv2":{
"operation":{
"type":"__torch__.torch.nn.modules.conv.Conv2d",
"parameters":{
"out_channels":64,
"in_channels":32,
"kernel_size":5
}
}
},
"_model__stem__pool2":{
"operation":{
"type":"__torch__.torch.nn.modules.pooling.MaxPool2d",
"parameters":{
"kernel_size":2
}
}
}
},
"edges":[
{
"head":[
"_inputs",
0
],
"tail":[
"_model__stem__conv1",
0
]
},
{
"head":[
"_model__stem__conv1",
null
],
"tail":[
"_model__stem__pool1",
0
]
},
{
"head":[
"_model__stem__pool1",
null
],
"tail":[
"_model__stem__conv2",
0
]
},
{
"head":[
"_model__stem__conv2",
null
],
"tail":[
"_model__stem__pool2",
0
]
},
{
"head":[
"_model__stem__pool2",
null
],
"tail":[
"_outputs",
null
]
}
]
},
"_model":{
"inputs":[
"image__1"
],
"outputs":[
"softmax__1"
],
"nodes":{
"_model__Constant2":{
"operation":{
"type":"prim::Constant",
"parameters":{
}
}
},
"_model__Constant3":{
"operation":{
"type":"prim::Constant",
"parameters":{
"value":3
}
}
},
"_model__Constant4":{
"operation":{
"type":"prim::Constant",
"parameters":{
"value":-1
}
}
},
"_model__Constant5":{
"operation":{
"type":"prim::Constant",
"parameters":{
"value":0
}
}
},
"_model__stem":{
"operation":{
"type":"_cell",
"parameters":{
},
"cell_name":"_model__stem"
}
},
"_model__Size6":{
"operation":{
"type":"aten::size",
"parameters":{
}
}
},
"_model__ListConstruct7":{
"operation":{
"type":"prim::ListConstruct",
"parameters":{
}
}
},
"_model__View8":{
"operation":{
"type":"aten::view",
"parameters":{
}
}
},
"_model__fc1":{
"operation":{
"type":"__torch__.torch.nn.modules.linear.Linear",
"parameters":{
"in_features":1024,
"out_features":256
}
}
},
"_model__fc2":{
"operation":{
"type":"__torch__.torch.nn.modules.linear.Linear",
"parameters":{
"in_features":256,
"out_features":10
}
}
},
"_model__softmax9":{
"operation":{
"type":"Function.softmax",
"parameters":{
}
}
}
},
"edges":[
{
"head":[
"_inputs",
0
],
"tail":[
"_model__stem",
0
]
},
{
"head":[
"_model__stem",
null
],
"tail":[
"_model__Size6",
0
]
},
{
"head":[
"_model__Constant5",
null
],
"tail":[
"_model__Size6",
1
]
},
{
"head":[
"_model__Size6",
null
],
"tail":[
"_model__ListConstruct7",
0
]
},
{
"head":[
"_model__Constant4",
null
],
"tail":[
"_model__ListConstruct7",
1
]
},
{
"head":[
"_model__stem",
null
],
"tail":[
"_model__View8",
0
]
},
{
"head":[
"_model__ListConstruct7",
null
],
"tail":[
"_model__View8",
1
]
},
{
"head":[
"_model__View8",
null
],
"tail":[
"_model__fc1",
0
]
},
{
"head":[
"_model__fc1",
null
],
"tail":[
"_model__fc2",
0
]
},
{
"head":[
"_model__fc2",
null
],
"tail":[
"_model__softmax9",
0
]
},
{
"head":[
"_model__Constant4",
null
],
"tail":[
"_model__softmax9",
1
]
},
{
"head":[
"_model__Constant3",
null
],
"tail":[
"_model__softmax9",
2
]
},
{
"head":[
"_model__Constant2",
null
],
"tail":[
"_model__softmax9",
3
]
},
{
"head":[
"_model__softmax9",
null
],
"tail":[
"_outputs",
null
]
}
]
},
"_evaluator": {
"module": "nni.retiarii.trainer.PyTorchImageClassificationTrainer",
"kwargs": {
"dataset_cls": "MNIST",
"dataset_kwargs": {
"root": "data/mnist",
"download": true
},
"dataloader_kwargs": {
"batch_size": 32
},
"optimizer_cls" : "SGD",
"optimizer_kwargs": {
"lr": 1e-3
},
"trainer_kwargs": {
"max_epochs": 1
}
}
}
}

View file

@ -1,63 +1,318 @@
import json
import os
import sys
import threading
import unittest
import logging
import time
import torch
import torch.nn as nn
from pathlib import Path
from nni.retiarii.execution.cgo_engine import CGOExecutionEngine
from nni.retiarii.execution.logical_optimizer.logical_plan import LogicalPlan
from nni.retiarii.execution.logical_optimizer.opt_dedup_input import DedupInputOptimizer
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii import Model, Node
import nni
from nni.retiarii import Model, submit_models
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.integration import RetiariiAdvisor
from nni.retiarii.evaluator.pytorch import PyTorchImageClassificationTrainer, PyTorchMultiModelTrainer
from nni.retiarii.utils import import_
try:
from nni.common.device import GPUDevice
from nni.retiarii.execution.cgo_engine import CGOExecutionEngine
from nni.retiarii import Model
from nni.retiarii.graph import Node
from nni.retiarii import Model, submit_models
from nni.retiarii.integration import RetiariiAdvisor
from nni.retiarii.execution import set_execution_engine
from nni.retiarii.execution.logical_optimizer.opt_dedup_input import DedupInputOptimizer
from nni.retiarii.execution.logical_optimizer.logical_plan import LogicalPlan
from nni.retiarii.utils import import_
from nni.retiarii import serialize
import nni.retiarii.evaluator.pytorch.lightning as pl
from nni.retiarii.evaluator.pytorch.cgo.evaluator import MultiModelSupervisedLearningModule, _MultiModelSupervisedLearningModule
import nni.retiarii.evaluator.pytorch.cgo.trainer as cgo_trainer
module_import_failed = False
except ImportError:
module_import_failed = True
import pytest
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import Dataset
from sklearn.datasets import load_diabetes
class _model_cpu(nn.Module):
def __init__(self):
super().__init__()
self.M_1_stem = M_1_stem()
self.M_2_stem = M_2_stem()
self.M_1_flatten = torch.nn.Flatten()
self.M_2_flatten = torch.nn.Flatten()
self.M_1_fc1 = torch.nn.Linear(out_features=256, in_features=1024)
self.M_2_fc1 = torch.nn.Linear(out_features=256, in_features=1024)
self.M_1_fc2 = torch.nn.Linear(out_features=10, in_features=256)
self.M_2_fc2 = torch.nn.Linear(out_features=10, in_features=256)
self.M_1_softmax = torch.nn.Softmax()
self.M_2_softmax = torch.nn.Softmax()
def forward(self, *_inputs):
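# both merged models consume the same deduplicated input tensor (_inputs[0]),
# which is the effect of DedupInputOptimizer on the logical plan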
M_1__inputs_to_M_2_stem = _inputs[0]
M_1_stem = self.M_1_stem(_inputs[0])
M_2_stem = self.M_2_stem(M_1__inputs_to_M_2_stem)
M_1_flatten = self.M_1_flatten(M_1_stem)
M_2_flatten = self.M_2_flatten(M_2_stem)
M_1_fc1 = self.M_1_fc1(M_1_flatten)
M_2_fc1 = self.M_2_fc1(M_2_flatten)
M_1_fc2 = self.M_1_fc2(M_1_fc1)
M_2_fc2 = self.M_2_fc2(M_2_fc1)
M_1_softmax = self.M_1_softmax(M_1_fc2)
M_2_softmax = self.M_2_softmax(M_2_fc2)
return M_1_softmax, M_2_softmax
class _model_gpu(nn.Module):
def __init__(self):
super().__init__()
self.M_1_stem = M_1_stem().to('cuda:0')
self.M_2_stem = M_2_stem().to('cuda:1')
self.M_1_flatten = torch.nn.Flatten().to('cuda:0')
self.M_2_flatten = torch.nn.Flatten().to('cuda:1')
self.M_1_fc1 = torch.nn.Linear(out_features=256, in_features=1024).to('cuda:0')
self.M_2_fc1 = torch.nn.Linear(out_features=256, in_features=1024).to('cuda:1')
self.M_1_fc2 = torch.nn.Linear(out_features=10, in_features=256).to('cuda:0')
self.M_2_fc2 = torch.nn.Linear(out_features=10, in_features=256).to('cuda:1')
self.M_1_softmax = torch.nn.Softmax().to('cuda:0')
self.M_2_softmax = torch.nn.Softmax().to('cuda:1')
def forward(self, *_inputs):
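# with device placement, the shared input is copied once to each model's
# device (cuda:0 / cuda:1) before feeding the corresponding stem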
M_1__inputs_to_M_1_stem = _inputs[0].to("cuda:0")
M_1__inputs_to_M_2_stem = _inputs[0].to("cuda:1")
M_1_stem = self.M_1_stem(M_1__inputs_to_M_1_stem)
M_2_stem = self.M_2_stem(M_1__inputs_to_M_2_stem)
M_1_flatten = self.M_1_flatten(M_1_stem)
M_2_flatten = self.M_2_flatten(M_2_stem)
M_1_fc1 = self.M_1_fc1(M_1_flatten)
M_2_fc1 = self.M_2_fc1(M_2_flatten)
M_1_fc2 = self.M_1_fc2(M_1_fc1)
M_2_fc2 = self.M_2_fc2(M_2_fc1)
M_1_softmax = self.M_1_softmax(M_1_fc2)
M_2_softmax = self.M_2_softmax(M_2_fc2)
return M_1_softmax, M_2_softmax
class M_1_stem(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(out_channels=32, in_channels=1, kernel_size=5)
self.pool1 = torch.nn.MaxPool2d(kernel_size=2)
self.conv2 = torch.nn.Conv2d(out_channels=64, in_channels=32, kernel_size=5)
self.pool2 = torch.nn.MaxPool2d(kernel_size=2)
def forward(self, *_inputs):
conv1 = self.conv1(_inputs[0])
pool1 = self.pool1(conv1)
conv2 = self.conv2(pool1)
pool2 = self.pool2(conv2)
return pool2
class M_2_stem(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(out_channels=32, in_channels=1, kernel_size=5)
self.pool1 = torch.nn.MaxPool2d(kernel_size=2)
self.conv2 = torch.nn.Conv2d(out_channels=64, in_channels=32, kernel_size=5)
self.pool2 = torch.nn.MaxPool2d(kernel_size=2)
def forward(self, *_inputs):
conv1 = self.conv1(_inputs[0])
pool1 = self.pool1(conv1)
conv2 = self.conv2(pool1)
pool2 = self.pool2(conv2)
return pool2
def _reset():
# reset global state so that other tests in the SDK are not affected
nni.trial._intermediate_seq = 0
nni.trial._params = {'foo': 'bar', 'parameter_id': 0}
nni.runtime.platform.test._last_metric = None
nni.retiarii.integration_api._advisor = None
nni.retiarii.execution.api._execution_engine = None
def _new_trainer():
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = serialize(MNIST, root='data/mnist', train=True, download=True, transform=transform)
test_dataset = serialize(MNIST, root='data/mnist', train=False, download=True, transform=transform)
multi_module = MultiModelSupervisedLearningModule(nn.CrossEntropyLoss, {'acc': pl._AccuracyWithLogits})
lightning = pl.Lightning(multi_module, cgo_trainer.Trainer(use_cgo=True,
max_epochs=1,
limit_train_batches=0.25,
progress_bar_refresh_rate=0),
train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
val_dataloaders=pl.DataLoader(test_dataset, batch_size=100))
return lightning
def _load_mnist(n_models: int = 1):
path = Path(__file__).parent / 'converted_mnist_pytorch.json'
path = Path(__file__).parent / 'mnist_pytorch.json'
with open(path) as f:
mnist_model = Model._load(json.load(f))
mnist_model.evaluator = _new_trainer()
if n_models == 1:
return mnist_model
else:
models = [mnist_model]
for i in range(n_models-1):
models.append(mnist_model.fork())
for i in range(n_models - 1):
forked_model = mnist_model.fork()
forked_model.evaluator = _new_trainer()
models.append(forked_model)
return models
@unittest.skip('Skipped in this version')
def _get_final_result():
result = json.loads(nni.runtime.platform.test._last_metric)['value']
if isinstance(result, list):
return [float(_) for _ in result]
else:
if isinstance(result, str) and '[' in result:
return json.loads(result)
return [float(result)]
class CGOEngineTest(unittest.TestCase):
def setUp(self):
if module_import_failed:
self.skipTest('test skip due to failed import of nni.retiarii.evaluator.pytorch.lightning')
def test_multi_model_trainer_cpu(self):
_reset()
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = serialize(MNIST, root='data/mnist', train=True, download=True, transform=transform)
test_dataset = serialize(MNIST, root='data/mnist', train=False, download=True, transform=transform)
multi_module = _MultiModelSupervisedLearningModule(nn.CrossEntropyLoss, {'acc': pl._AccuracyWithLogits}, n_models=2)
lightning = pl.Lightning(multi_module, cgo_trainer.Trainer(use_cgo=True,
max_epochs=1,
limit_train_batches=0.25),
train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
val_dataloaders=pl.DataLoader(test_dataset, batch_size=100))
lightning._execute(_model_cpu)
result = _get_final_result()
assert len(result) == 2
for _ in result:
assert _ > 0.8
def test_multi_model_trainer_gpu(self):
_reset()
if not (torch.cuda.is_available() and torch.cuda.device_count() >= 2):
pytest.skip('test requires GPU and torch+cuda')
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = serialize(MNIST, root='data/mnist', train=True, download=True, transform=transform)
test_dataset = serialize(MNIST, root='data/mnist', train=False, download=True, transform=transform)
multi_module = _MultiModelSupervisedLearningModule(nn.CrossEntropyLoss, {'acc': pl._AccuracyWithLogits}, n_models=2)
lightning = pl.Lightning(multi_module, cgo_trainer.Trainer(use_cgo=True,
max_epochs=1,
limit_train_batches=0.25),
train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
val_dataloaders=pl.DataLoader(test_dataset, batch_size=100))
lightning._execute(_model_gpu)
result = _get_final_result()
assert len(result) == 2
for _ in result:
assert _ > 0.8
def _build_logical_with_mnist(self, n_models: int):
lp = LogicalPlan()
models = _load_mnist(n_models=n_models)
for m in models:
lp.add_model(m)
return lp, models
def test_add_model(self):
_reset()
lp, models = self._build_logical_with_mnist(3)
for node in lp.logical_graph.hidden_nodes:
old_nodes = [m.root_graph.get_node_by_id(node.id) for m in models]
self.assertTrue(any([old_nodes[0].__repr__() == Node.__repr__(x) for x in old_nodes]))
def test_dedup_input_four_devices(self):
_reset()
lp, models = self._build_logical_with_mnist(3)
opt = DedupInputOptimizer()
opt.convert(lp)
advisor = RetiariiAdvisor()
available_devices = [GPUDevice("test", 0), GPUDevice("test", 1), GPUDevice("test", 2), GPUDevice("test", 3)]
cgo = CGOExecutionEngine(devices=available_devices, batch_waiting_time=0)
phy_models = cgo._assemble(lp)
self.assertTrue(len(phy_models) == 1)
advisor.stopping = True
advisor.default_worker.join()
advisor.assessor_worker.join()
cgo.join()
def test_dedup_input_two_devices(self):
_reset()
lp, models = self._build_logical_with_mnist(3)
opt = DedupInputOptimizer()
opt.convert(lp)
advisor = RetiariiAdvisor()
available_devices = [GPUDevice("test", 0), GPUDevice("test", 1)]
cgo = CGOExecutionEngine(devices=available_devices, batch_waiting_time=0)
phy_models = cgo._assemble(lp)
self.assertTrue(len(phy_models) == 2)
advisor.stopping = True
advisor.default_worker.join()
advisor.assessor_worker.join()
cgo.join()
def test_submit_models(self):
os.environ['CGO'] = 'true'
_reset()
nni.retiarii.debug_configs.framework = 'pytorch'
os.makedirs('generated', exist_ok=True)
from nni.runtime import protocol, platform
from nni.runtime import protocol
import nni.runtime.platform.test as tt
protocol._out_file = open('generated/debug_protocol_out_file.py', 'wb')
protocol._in_file = open('generated/debug_protocol_out_file.py', 'rb')
models = _load_mnist(2)
advisor = RetiariiAdvisor()
cgo_engine = CGOExecutionEngine(devices=[GPUDevice("test", 0), GPUDevice("test", 1),
GPUDevice("test", 2), GPUDevice("test", 3)], batch_waiting_time=0)
set_execution_engine(cgo_engine)
submit_models(*models)
time.sleep(3)
if torch.cuda.is_available() and torch.cuda.device_count() >= 2:
cmd, data = protocol.receive()
params = json.loads(data)
params['parameters']['training_kwargs']['max_steps'] = 100
tt.init_params(params)
trial_thread = threading.Thread(target=CGOExecutionEngine.trial_execute_graph())
trial_thread = threading.Thread(target=CGOExecutionEngine.trial_execute_graph)
trial_thread.start()
last_metric = None
while True:
@ -66,15 +321,20 @@ class CGOEngineTest(unittest.TestCase):
metric = tt.get_last_metric()
if metric == last_metric:
continue
if 'value' in metric:
metric['value'] = json.dumps(metric['value'])
advisor.handle_report_metric_data(metric)
last_metric = metric
if not trial_thread.is_alive():
trial_thread.join()
break
trial_thread.join()
advisor.stopping = True
advisor.default_worker.join()
advisor.assessor_worker.join()
cgo_engine.join()
if __name__ == '__main__':

View file

@ -1,86 +0,0 @@
import json
import os
import sys
import threading
import unittest
import logging
import time
from pathlib import Path
from nni.retiarii.execution.cgo_engine import CGOExecutionEngine
from nni.retiarii.execution.logical_optimizer.logical_plan import LogicalPlan
from nni.retiarii.execution.logical_optimizer.opt_dedup_input import DedupInputOptimizer
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii import Model, Node
from nni.retiarii import Model, submit_models
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.integration import RetiariiAdvisor
from nni.retiarii.utils import import_
def _load_mnist(n_models: int = 1):
path = Path(__file__).parent / 'converted_mnist_pytorch.json'
with open(path) as f:
mnist_model = Model._load(json.load(f))
if n_models == 1:
return mnist_model
else:
models = [mnist_model]
for i in range(n_models-1):
models.append(mnist_model.fork())
return models
@unittest.skip('Skipped in this version')
class DedupInputTest(unittest.TestCase):
def _build_logical_with_mnist(self, n_models: int):
lp = LogicalPlan()
models = _load_mnist(n_models=n_models)
for m in models:
lp.add_model(m)
return lp, models
def _test_add_model(self):
lp, models = self._build_logical_with_mnist(3)
for node in lp.logical_graph.hidden_nodes:
old_nodes = [m.root_graph.get_node_by_id(node.id) for m in models]
self.assertTrue(any([old_nodes[0].__repr__() == Node.__repr__(x) for x in old_nodes]))
def test_dedup_input(self):
os.environ['CGO'] = 'true'
lp, models = self._build_logical_with_mnist(3)
opt = DedupInputOptimizer()
opt.convert(lp)
with open('dedup_logical_graph.json', 'r') as fp:
correct_dump = fp.readlines()
lp_dump = lp.logical_graph._dump()
self.assertTrue(correct_dump[0] == json.dumps(lp_dump))
advisor = RetiariiAdvisor()
cgo = CGOExecutionEngine()
phy_models = cgo._assemble(lp)
self.assertTrue(len(phy_models) == 1)
# logging.info(phy_models[0][0]._dump())
# script=model_to_pytorch_script(phy_models[0][0], placement = phy_models[0][1])
# logging.info(script)
# with open('generated/debug_dedup_input.py', 'w') as fp:
# fp.write(script)
# sys.path.insert(0, 'generated')
# multi_model = import_('debug_dedup_input.logical_0')
# trainer = PyTorchMultiModelTrainer(
# multi_model(), phy_models[0][0].evaluator.kwargs
# )
# trainer.fit()
advisor.stopping = True
advisor.default_worker.join()
advisor.assessor_worker.join()
if __name__ == '__main__':
unittest.main()

View file

@ -22,6 +22,8 @@ class EngineTest(unittest.TestCase):
self.assertEqual(script.strip(), reference_script.strip())
def test_base_execution_engine(self):
nni.retiarii.integration_api._advisor = None
nni.retiarii.execution.api._execution_engine = None
advisor = RetiariiAdvisor()
set_execution_engine(BaseExecutionEngine())
with open(self.enclosing_dir / 'mnist_pytorch.json') as f:
@ -33,7 +35,8 @@ class EngineTest(unittest.TestCase):
advisor.assessor_worker.join()
def test_py_execution_engine(self):
nni.retiarii.integration_api._advisor = None
nni.retiarii.execution.api._execution_engine = None
advisor = RetiariiAdvisor()
set_execution_engine(PurePythonExecutionEngine())
model = Model._load({