NAS execution engine (stage 3) - CGO (#5360)

This commit is contained in:
Yuge Zhang 2023-02-27 20:23:04 +08:00 committed by GitHub
Parent 1e4f3f08d3
Commit 67a61e1d94
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 879 additions and 798 deletions

View file

@@ -1,4 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .engine import *
from .middleware import CrossGraphOptimization

View file

@@ -0,0 +1,148 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
__all__ = ['MultiModelLightningModule', 'MultiModelTrainer']
import pytorch_lightning
import torch
from pytorch_lightning.strategies import SingleDeviceStrategy
from torch import nn
from torchmetrics import Metric
import nni
from nni.nas.evaluator.pytorch.lightning import LightningModule
class MultiModelLightningModule(LightningModule):
"""The lightning module for a merged "multi-model".
The output of the multi-model is expected to be a tuple of tensors.
The tensors will each be passed to a criterion and a metric.
The losses will be summed for backpropagation, and the metrics will be logged.
The reported metric will be a list of metrics, one for each model.
Parameters
----------
criterion
Loss function.
metric
Metric function.
n_models
Number of models in the multi-model.
"""
def __init__(self, criterion: nn.Module, metric: Metric, n_models: int | None = None):
super().__init__()
self.criterion = criterion
self.metric = metric
self.n_models = n_models
def _dump(self) -> dict:
return {
'criterion': self.criterion,
'metric': self.metric,
'n_models': self.n_models,
}
@staticmethod
def _load(criterion: nn.Module, metric: Metric, n_models: int | None = None) -> MultiModelLightningModule:
return MultiModelLightningModule(criterion, metric, n_models)
def forward(self, x):
y_hat = self.model(x)
return y_hat
def training_step(self, batch, batch_idx):
x, y = batch
multi_y_hat = self(x)
if isinstance(multi_y_hat, tuple):
assert len(multi_y_hat) == self.n_models
else:
assert self.n_models == 1
multi_y_hat = [multi_y_hat]
multi_loss = []
for idx, y_hat in enumerate(multi_y_hat):
loss = self.criterion(y_hat.to("cpu"), y.to("cpu"))
self.log(f'train_loss_{idx}', loss, prog_bar=True)
self.log(f'train_metric_{idx}', self.metric(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)
multi_loss.append(loss)
return sum(multi_loss)
def validation_step(self, batch, batch_idx):
x, y = batch
multi_y_hat = self(x)
if isinstance(multi_y_hat, tuple):
assert len(multi_y_hat) == self.n_models
else:
assert self.n_models == 1
multi_y_hat = [multi_y_hat]
for idx, y_hat in enumerate(multi_y_hat):
self.log(f'val_loss_{idx}', self.criterion(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)
self.log(f'val_metric_{idx}', self.metric(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)
def test_step(self, batch, batch_idx):
x, y = batch
multi_y_hat = self(x)
if isinstance(multi_y_hat, tuple):
assert len(multi_y_hat) == self.n_models
else:
assert self.n_models == 1
multi_y_hat = [multi_y_hat]
for idx, y_hat in enumerate(multi_y_hat):
self.log(f'test_loss_{idx}', self.criterion(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)
self.log(f'test_metric_{idx}', self.metric(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=1e-3)
def on_validation_epoch_end(self):
nni.report_intermediate_result(self._get_validation_metrics())
def teardown(self, stage):
if stage == 'fit':
nni.report_final_result(self._get_validation_metrics())
def _get_validation_metrics(self):
# TODO: split metric of multiple models?
assert self.n_models is not None
return [self.trainer.callback_metrics[f'val_metric_{idx}'].item() for idx in range(self.n_models)]
class _BypassStrategy(SingleDeviceStrategy):
strategy_name = "single_device"
def model_to_device(self) -> None:
pass
@nni.trace
class MultiModelTrainer(pytorch_lightning.Trainer):
"""
Trainer for cross-graph optimization.
Parameters
----------
use_cgo
Whether cross-graph optimization (CGO) is used.
If it is True, CGO will manage device placement.
Any device placement from PyTorch Lightning will be bypassed.
default: True
trainer_kwargs
Optional keyword arguments passed to trainer. See
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
"""
def __init__(self, use_cgo: bool = True, **trainer_kwargs):
if use_cgo:
if "accelerator" in trainer_kwargs:
raise ValueError("accelerator should not be set when cross-graph optimization is enabled.")
if 'strategy' in trainer_kwargs:
raise ValueError("MultiModelTrainer does not support specifying strategy")
trainer_kwargs['strategy'] = _BypassStrategy()
super().__init__(**trainer_kwargs)
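For orientation, here is a minimal usage sketch of the two classes above, mirroring how the rewritten unit test further down in this diff builds its evaluator; the dataset root, batch size, and epoch count are illustrative placeholders rather than part of this commit.
import torch.nn as nn
import torchmetrics
import nni
from torchvision import transforms
from torchvision.datasets import MNIST
from nni.nas.evaluator.pytorch import Lightning, DataLoader
from nni.nas.execution.cgo.evaluator import MultiModelLightningModule, MultiModelTrainer

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = nni.trace(MNIST)(root='data/mnist', train=True, download=True, transform=transform)
test_dataset = nni.trace(MNIST)(root='data/mnist', train=False, download=True, transform=transform)

# The criterion and metric are applied to every model in the merged "multi-model".
# n_models is left unset here because CGO fills it in when models are assembled.
module = MultiModelLightningModule(nn.CrossEntropyLoss(), torchmetrics.Accuracy('multiclass', num_classes=10))

# use_cgo defaults to True, so the trainer installs the bypass strategy and
# leaves device placement to CGO.
trainer = MultiModelTrainer(max_epochs=1, limit_train_batches=0.25)

evaluator = Lightning(module, trainer,
                      train_dataloaders=DataLoader(train_dataset, batch_size=100),
                      val_dataloaders=DataLoader(test_dataset, batch_size=100))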

View file

@@ -2,13 +2,13 @@
# Licensed under the MIT license.
import copy
from typing import Dict, Tuple, Any
from typing import Dict, Tuple, Any, Type
from nni.retiarii.utils import uid
from nni.common.device import Device, CPUDevice
from nni.mutable.utils import uid
from nni.nas.execution.common.graph import Cell, Edge, Graph, Model, Node
from nni.nas.execution.common.graph_op import Operation, _IOPseudoOperation
from nni.nas.space import Edge, Graph, GraphModelSpace, Node
from nni.nas.space.graph_op import Cell, Operation, _IOPseudoOperation
class AbstractLogicalNode(Node):
@@ -16,7 +16,7 @@ class AbstractLogicalNode(Node):
super().__init__(graph, node_id, name, operation, _internal=_internal)
self.related_models = []
def assemble(self, multi_model_placement: Dict[Model, Device]) -> Tuple[Node, Device]:
def assemble(self, multi_model_placement: Dict[GraphModelSpace, Device]) -> Tuple[Node, Device]:
"""
Given a set of models to be merged into one physical model and their device placement,
this function replaces the logical node with an executable physical node for the physical model.
@@ -42,7 +42,7 @@ class AbstractLogicalNode(Node):
class LogicalGraph(Graph):
def __init__(self, model: Model, graph_id: int, name: str = None, _internal: bool = False):
def __init__(self, model: GraphModelSpace, graph_id: int, name: str = None, _internal: bool = False):
super().__init__(model, graph_id, name='logical_' + name, _internal=_internal)
def _dump(self) -> Any:
@@ -71,7 +71,7 @@ class LogicalGraph(Graph):
'edges': edges_dump
}
def _fork_to(self, model: Model) -> Graph:
def _fork_to(self, model: GraphModelSpace) -> Graph:
new_graph = Graph(model, self.id, self.name,
_internal=True)._register()
@@ -106,7 +106,7 @@ class OriginNode(AbstractLogicalNode):
self.original_graph = original_graph
self.original_node = original_node
def assemble(self, multi_model_placement: Dict[Model, Device]) -> Tuple[Node, Device]:
def assemble(self, multi_model_placement: Dict[GraphModelSpace, Device]) -> Tuple[Node, Device]:
model_id = self.original_node.graph.model.model_id
new_node = Node(self.original_node.graph, self.original_node.id,
f"M_{model_id}_" +
@@ -124,15 +124,17 @@ class OriginNode(AbstractLogicalNode):
class LogicalPlan:
def __init__(self, plan_id=0) -> None:
self.lp_model = Model(_internal=True)
def __init__(self, model_cls: Type[GraphModelSpace], plan_id: int = 0) -> None:
# GraphModelSpace has multiple implementations based on the framework.
self.model_cls = model_cls
self.lp_model = model_cls(_internal=True)
self.id = plan_id
self.logical_graph = LogicalGraph(
self.lp_model, self.id, name=f'{self.id}', _internal=True)._register()
self.lp_model._root_graph_name = self.logical_graph.name
self.models = []
def add_model(self, model: Model):
def add_model(self, model: GraphModelSpace):
self.models.append(model)
# Only optimize the root graph.
self._merge_graph(model.root_graph)
@@ -152,8 +154,8 @@ class LogicalPlan:
new_tail = id_to_new_node[edge.tail.id]
Edge((new_head, edge.head_slot), (new_tail, edge.tail_slot), _internal=True)._register()
def assemble(self, multi_model_placement: Dict[Model, Device]) \
-> Tuple[Model, Dict[Node, Device]]:
def assemble(self, multi_model_placement: Dict[GraphModelSpace, Device]) \
-> Tuple[GraphModelSpace, Dict[Node, Device]]:
"""
Given a set of models to be merged into one physical model and their device placement,
this function replaces all the logical nodes in this LogicalPlan with executable physical nodes
@@ -167,13 +169,13 @@ class LogicalPlan:
Returns
-------
phy_model : Model
phy_model : GraphModelSpace
the physical model formed by the models in `multi_model_placement`;
all logical nodes are replaced by physical nodes
node_placements : dict
the device placement of the nodes in `phy_model`
"""
phy_model = Model(_internal=True)
phy_model = self.model_cls(_internal=True)
phy_graph = self.lp_model.root_graph._fork_to(phy_model)
phy_graph._rename_graph(phy_graph.name, "_model")
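The constructor now takes a ``model_cls`` because GraphModelSpace has framework-specific subclasses. A rough sketch of the intended flow, assuming a list of frozen ``PytorchGraphModelSpace`` models and a pre-computed device assignment (both are placeholders, not part of this commit):
from nni.nas.space.pytorch import PytorchGraphModelSpace
from nni.nas.execution.cgo.logical_optimizer.logical_plan import LogicalPlan

plan = LogicalPlan(model_cls=PytorchGraphModelSpace, plan_id=0)
for model in models:          # `models` assumed: frozen PytorchGraphModelSpace instances
    plan.add_model(model)     # merges each model's root graph into the logical graph

# `placement` assumed: Dict[GraphModelSpace, Device], one device per model to be merged
phy_model, node_placement = plan.assemble(placement)
# phy_model is a single PytorchGraphModelSpace containing all merged models;
# node_placement maps each of its nodes to the device it should run on.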

View file

@@ -3,17 +3,12 @@
from typing import List, Dict, Tuple
from nni.nas.utils import uid
from nni.nas.evaluator.pytorch.cgo.evaluator import MultiModelSupervisedLearningModule
from nni.mutable.utils import uid
from nni.common.device import GPUDevice
from nni.nas.execution.common.graph import Graph, Model, Node
from nni.nas.space import GraphModelSpace, Graph, Node
from .interface import AbstractOptimizer
from .logical_plan import (AbstractLogicalNode, LogicalGraph, LogicalPlan,
OriginNode)
_supported_evaluators = [MultiModelSupervisedLearningModule]
from .logical_plan import AbstractLogicalNode, LogicalGraph, LogicalPlan, OriginNode
class DedupInputNode(AbstractLogicalNode):
@@ -23,7 +18,6 @@ class DedupInputNode(AbstractLogicalNode):
These models will share the result of a single calculation.
"""
def __init__(self, logical_graph: LogicalGraph, node_id: int,
nodes_to_dedup: List[Node], _internal=False):
super().__init__(logical_graph, node_id,
@@ -32,7 +26,7 @@ class DedupInputNode(AbstractLogicalNode):
self.origin_nodes: List[OriginNode] = nodes_to_dedup.copy()
self.related_models = [_.original_graph.model for _ in self.origin_nodes]
def assemble(self, multi_model_placement: Dict[Model, GPUDevice]) -> Tuple[Node, GPUDevice]:
def assemble(self, multi_model_placement: Dict[GraphModelSpace, GPUDevice]) -> Tuple[Node, GPUDevice]:
for node in self.origin_nodes:
if node.original_graph.model in multi_model_placement:
new_node = Node(node.original_graph, node.id,
@@ -54,10 +48,10 @@ class DedupInputOptimizer(AbstractOptimizer):
pass
def _check_supported_evaluator(self, evaluator):
for e in _supported_evaluators:
if isinstance(evaluator, e):
return True
return False
# NOTE(yuge): I think this is buggy. But I'm not sure whether I should fix it.
from nni.nas.execution.cgo.evaluator import MultiModelLightningModule
_supported_evaluators = (MultiModelLightningModule, )
return isinstance(evaluator, _supported_evaluators)
def _check_deduplicate_by_node(self, root_node, node_to_check):
if root_node == node_to_check:
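As a rough sketch of how this optimizer is driven (condensed from the rewritten test near the end of this diff; ``lp`` is assumed to be a LogicalPlan built from several models that read the same input, and ``cgo`` a CrossGraphOptimization instance):
from nni.nas.execution.cgo.logical_optimizer.opt_dedup_input import DedupInputOptimizer

opt = DedupInputOptimizer()
opt.convert(lp)                       # duplicated input nodes are merged into DedupInputNode
phy_models = list(cgo._assemble(lp))  # shared inputs let several models fit into fewer trials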

View file

@@ -0,0 +1,394 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
__all__ = ['CrossGraphOptimization']
import logging
import time
import threading
from collections.abc import Iterable
from typing import List, Dict, Tuple, cast
from nni.common.device import GPUDevice, Device
from nni.experiment.config.training_services import RemoteConfig
from nni.nas.space import GraphModelSpace, Node, ModelStatus, ExecutableModelSpace
from nni.nas.execution.engine import Middleware, ExecutionEngine
from nni.nas.execution.event import ModelEventType, IntermediateMetricEvent, FinalMetricEvent, TrainingEndEvent
from nni.typehint import TrialMetric
from .logical_optimizer.logical_plan import LogicalPlan, AbstractLogicalNode
from .logical_optimizer.opt_dedup_input import DedupInputOptimizer
_logger = logging.getLogger(__name__)
class CrossGraphOptimization(Middleware):
"""
The execution engine middleware of Cross-Graph Optimization (CGO).
It's a technique that merges multiple models into one model for training speedup.
See `Retiarii paper <https://www.usenix.org/system/files/osdi20-zhang_quanlu.pdf>`__ for details.
Currently, :class:`CrossGraphOptimization` is only a prototype.
It's not fully tested, and comes with a number of constraints on the model space and evaluator:
- The models must be in the format of :class:`~nni.nas.space.GraphModelSpace`.
- The evaluator has to be a :class:`~nni.nas.evaluator.pytorch.Lightning` evaluator.
- The ``lightning_module`` argument of the evaluator must be an instance of
:class:`~nni.nas.execution.cgo.evaluator.MultiModelLightningModule`.
- The ``trainer`` argument of the evaluator must be an instance of
:class:`~nni.nas.execution.cgo.evaluator.MultiModelTrainer`.
There are also a number of limitations:
- CGO doesn't support stopping and resuming from a checkpoint.
- Only the remote training service is supported.
- All model history is stored in memory, so the experiment might not scale well.
Parameters
----------
remote_config
The remote training service config.
max_concurrency
The maximum number of trials to run concurrently.
batch_waiting_time
Seconds to wait for each batch of trial submission.
The trials within one batch could apply cross-graph optimization.
"""
def __init__(self, remote_config: RemoteConfig,
max_concurrency: int | None = None,
batch_waiting_time: int = 60) -> None:
super().__init__()
_logger.warning('Cross graph optimization is an experimental feature. Usages are subject to change.')
self._history: List[GraphModelSpace] = []
self._running_models: Dict[int, GraphModelSpace] = {}
self.logical_plan_counter = 0
self.available_devices: List[Device] = []
self.max_concurrency: int | None = max_concurrency
devices = self._construct_devices(remote_config)
for device in devices:
self.available_devices.append(device)
self.all_devices = self.available_devices.copy()
self._batch_waiting_time = batch_waiting_time # seconds to wait for all models in a batch to do cross-graph optimization
self._optimizers = [DedupInputOptimizer()]
self._original_models: Dict[int, GraphModelSpace] = {}
self._original_model_to_multi_model: Dict[int, GraphModelSpace] = {}
self._trial_to_original_models: Dict[int, List[GraphModelSpace]] = {}
self._trial_used_devices: Dict[int, List[Device]] = {}
self._queuing_models: List[GraphModelSpace] = []
self._models_to_retry: List[GraphModelSpace] = []
self._queue_lock = threading.Lock()
self._stopped = False
self._consumer_thread = threading.Thread(target=self._consume_models)
self._consumer_thread.start()
def _construct_devices(self, training_service):
devices = []
if hasattr(training_service, 'machine_list'):
for machine in cast(RemoteConfig, training_service).machine_list:
assert machine.gpu_indices is not None, \
'gpu_indices must be set in RemoteMachineConfig for CGO execution engine'
assert isinstance(machine.gpu_indices, list), 'gpu_indices must be a list'
for gpu_idx in machine.gpu_indices:
devices.append(GPUDevice(machine.host, gpu_idx))
return devices
def shutdown(self):
self._stopped = True
self._consumer_thread.join()
self.engine.unregister_model_event_callback(ModelEventType.TrainingEnd, self._training_end_callback)
self.engine.unregister_model_event_callback(ModelEventType.FinalMetric, self._final_metric_callback)
self.engine.unregister_model_event_callback(ModelEventType.IntermediateMetric, self._intermediate_metric_callback)
self.engine.shutdown()
def load_state_dict(self, state_dict: dict) -> None:
_logger.info('Cross graph optimization does not preserve any states by itself. Loading the state of inner engine: %s', self.engine)
return self.engine.load_state_dict(state_dict)
def state_dict(self) -> dict:
return self.engine.state_dict()
def set_engine(self, engine: ExecutionEngine) -> None:
super().set_engine(engine)
self.engine.register_model_event_callback(ModelEventType.TrainingEnd, self._training_end_callback)
self.engine.register_model_event_callback(ModelEventType.FinalMetric, self._final_metric_callback)
self.engine.register_model_event_callback(ModelEventType.IntermediateMetric, self._intermediate_metric_callback)
def add_optimizer(self, opt):
self._optimizers.append(opt)
def submit_models(self, *models: GraphModelSpace) -> None:
if any(not isinstance(model, GraphModelSpace) for model in models):
raise TypeError('Cross graph optimization only supports GraphModelSpace.')
curr_time = time.time()
_logger.info('%d models are submitted.', len(models))
with self._queue_lock:
self._queuing_models.extend([(curr_time, _) for _ in models])
self._history.extend(models)
def _submit_retry_models(self, models: List[GraphModelSpace]) -> None:
_logger.info('%d models are retried.', len(models))
with self._queue_lock:
self._models_to_retry.extend(models)
def _consume_models(self):
# a thread to monitor self._models_to_retry and self._queuing_models to consume them in batch
while not self._stopped:
# retrying jobs should be scheduled first.
while self._models_to_retry:
with self._queue_lock:
# Get next model and lock the resource.
if len(self.available_devices) > 0:
m = self._models_to_retry[0]
self._models_to_retry = self._models_to_retry[1:]
m = self._schedule_models_in_batch(m)
else:
break
# submit the single model to avoid cross-graph optimization.
self.engine.submit_models(*m)
time.sleep(1)
# Submit merged models
merged_models = []
with self._queue_lock:
curr_time = time.time()
num_models_to_submit = len(self.available_devices)
if self.max_concurrency is not None:
num_models_to_submit = min(num_models_to_submit, self.max_concurrency)
if self._queuing_models and curr_time - self._queuing_models[0][0] >= self._batch_waiting_time:
num_models_to_submit = min(num_models_to_submit, len(self._queuing_models))
if num_models_to_submit > 0:
merged_models = list(self._schedule_models_in_batch(*[_[1] for _ in self._queuing_models[:num_models_to_submit]]))
self._queuing_models = self._queuing_models[num_models_to_submit:]
_logger.debug('Scheduled %d models in batch.', num_models_to_submit)
# Outside lock to avoid deadlock.
if merged_models:
self.engine.submit_models(*merged_models)
time.sleep(1)
def _schedule_models_in_batch(self, *models: GraphModelSpace) -> Iterable[GraphModelSpace]:
_logger.info('%d models are scheduled in batch.', len(models))
_logger.debug('Scheduled model ids: %s', [m.model_id for m in models])
for model in models:
model.status = ModelStatus.Training
logical = self._build_logical(models)
for opt in self._optimizers:
opt.convert(logical)
for model, grouped_models in self._assemble(logical):
assert model.placement is not None
_logger.debug('Created grouped model %d. Original model ids: %s', model.model_id, [m.model_id for m in grouped_models])
# unique non-cpu devices used by the trial
self._trial_used_devices[model.model_id] = list(set([_ for _ in model.placement.values() if isinstance(_, GPUDevice)]))
_logger.debug('Model %d uses devices: %s', model.model_id, self._trial_used_devices[model.model_id])
# currently, it is impossible for the search strategy to submit more models than the number of available devices
for used_device in self._trial_used_devices[model.model_id]:
self.available_devices.remove(used_device) # used_device must be in self.available_devices
self._running_models[model.model_id] = model
self._trial_to_original_models[model.model_id] = []
for m in grouped_models:
self._original_models[m.model_id] = m
self._original_model_to_multi_model[m.model_id] = model
self._trial_to_original_models[model.model_id].append(m.model_id)
yield model
def list_models(self) -> Iterable[GraphModelSpace]:
return self._history
def idle_worker_available(self) -> bool:
# the _queuing_models need to use available_devices first
with self._queue_lock:
available_for_more_models = len(self.available_devices) - len(self._queuing_models) - len(self._models_to_retry)
return available_for_more_models
def budget_available(self) -> bool:
return self.engine.budget_available()
def _assemble(self, logical_plan: LogicalPlan) -> Iterable[Tuple[GraphModelSpace, List[GraphModelSpace]]]:
"""
Return the assembled models as an iterable of tuples.
Each tuple contains the assembled model (with its device placement set) and the original models it merges.
"""
# try to use the available_devices first so that it can be launched as early as possible
# if free devices are not enough to assemble all models in one trial, try all devices
if len(self.available_devices) > 0:
grouped_models: List[Dict[GraphModelSpace, Device]] = AssemblePolicy().group(logical_plan, self.available_devices)
if len(self.available_devices) == 0 or len(grouped_models) > 1:
grouped_models: List[Dict[GraphModelSpace, Device]] = AssemblePolicy().group(logical_plan, self.all_devices)
for multi_model in grouped_models:
model, model_placement = logical_plan.assemble(multi_model)
assert isinstance(model, GraphModelSpace), 'Assembled model must be a GraphModelSpace.'
from nni.nas.evaluator.pytorch import Lightning
from .evaluator import MultiModelLightningModule, MultiModelTrainer
if not isinstance(model.evaluator, Lightning):
raise TypeError('Cross-graph optimization only supports PyTorch Lightning as the evaluator.')
if not isinstance(model.evaluator.module, MultiModelLightningModule):
raise TypeError('Cross-graph optimization only supports MultiModelLightningModule.')
if not isinstance(model.evaluator.trainer, MultiModelTrainer):
raise TypeError('Cross-graph optimization only supports MultiModelTrainer.')
# Set n_models of the lightning module.
model.evaluator.module.n_models = len(multi_model)
model.status = ModelStatus.Frozen
model.placement = model_placement
model.metrics.strict = False
yield model, multi_model.keys()
def _build_logical(self, models: List[GraphModelSpace]) -> LogicalPlan:
assert len(models) > 0
logical_plan = LogicalPlan(model_cls=models[0].__class__, plan_id=self.logical_plan_counter)
for model in models:
logical_plan.add_model(model)
self.logical_plan_counter += 1
return logical_plan
def _training_end_callback(self, event: TrainingEndEvent) -> None:
model = cast(GraphModelSpace, event.model)
_logger.debug(f'Training end for merged model {model.model_id}.')
model = self._running_models[model.model_id]
models_to_retry = []
for model_id in self._original_model_to_multi_model:
if self._original_model_to_multi_model[model_id] == model:
original_model = self._original_models[model_id]
if model.status == ModelStatus.Trained:
self.dispatch_model_event(TrainingEndEvent(original_model, ModelStatus.Trained))
else:
# the failed models in a multi-model will be retried one by one w/o CGO
if len(self._trial_to_original_models[model.model_id]) > 1:
# TODO: should the listeners be notified?
original_model.status = ModelStatus.Frozen
original_model.metrics.clear()
models_to_retry.append(original_model)
else:
self.dispatch_model_event(TrainingEndEvent(original_model, ModelStatus.Failed))
if len(models_to_retry) > 0:
self._submit_retry_models(models_to_retry)
self.available_devices.extend(self._trial_used_devices[model.model_id])
self.available_devices = sorted(list(set(self.available_devices)))
del self._running_models[model.model_id]
def _intermediate_metric_callback(self, event: IntermediateMetricEvent) -> None:
model = cast(GraphModelSpace, event.model)
metrics = cast(List[TrialMetric], event.metric)
_logger.debug(f'Received intermediate metrics for merged model {model.model_id}: {metrics}')
if not isinstance(metrics, Iterable):
raise TypeError('Intermediate metrics must be a list of TrialMetric.')
if len(metrics) != len(self._trial_to_original_models[model.model_id]):
raise ValueError('Number of intermediate metrics must be equal to number of original models.')
merged_metrics: Dict[int, TrialMetric] = {}
for idx, _ in enumerate(metrics):
merged_metrics[self._trial_to_original_models[model.model_id][idx]] = metrics[idx]
for model_id in merged_metrics:
self.dispatch_model_event(IntermediateMetricEvent(self._original_models[model_id], merged_metrics[model_id]))
def _final_metric_callback(self, event: FinalMetricEvent) -> None:
model = cast(GraphModelSpace, event.model)
metrics = cast(List[TrialMetric], event.metric.final)
_logger.debug(f'Received final metrics for merged model {model.model_id}: {metrics}')
if not isinstance(metrics, Iterable):
raise TypeError('Final metrics must be a list of TrialMetric.')
if len(metrics) != len(self._trial_to_original_models[model.model_id]):
raise ValueError('Number of final metrics must be equal to number of original models.')
merged_metrics: Dict[int, TrialMetric] = {}
for idx, _ in enumerate(metrics):
merged_metrics[self._trial_to_original_models[model.model_id][idx]] = metrics[idx]
_logger.debug(f'Mapped to metrics of original models: {merged_metrics}')
for model_id in merged_metrics:
self.dispatch_model_event(FinalMetricEvent(self._original_models[model_id], merged_metrics[model_id]))
class AssemblePolicy:
@staticmethod
def _is_related_node(model: GraphModelSpace, node: Node):
if isinstance(node, AbstractLogicalNode):
if model in node.related_models:
return True
else:
if model == node.graph.model:
return True
return False
@staticmethod
def _check_graph_connectivity(model: GraphModelSpace,
group_model: Dict[GraphModelSpace, Device],
logical_plan: LogicalPlan) -> bool:
for edge in logical_plan.logical_graph.edges:
if AssemblePolicy._is_related_node(model, edge.head) or \
AssemblePolicy._is_related_node(model, edge.tail):
for grouped_model in group_model:
if AssemblePolicy._is_related_node(grouped_model, edge.head) or \
AssemblePolicy._is_related_node(grouped_model, edge.tail):
return True
return False
@staticmethod
def _check_evaluator(new_model: GraphModelSpace, group_model: Dict[GraphModelSpace, Device]) -> bool:
from nni.nas.evaluator.pytorch import Lightning
from .evaluator import MultiModelLightningModule, MultiModelTrainer
if not (isinstance(new_model.evaluator, Lightning)
and isinstance(new_model.evaluator.module, MultiModelLightningModule)
and isinstance(new_model.evaluator.trainer, MultiModelTrainer)):
return False
for m in group_model:
if not m.evaluator == new_model.evaluator:
return False
return True
@staticmethod
def group(logical_plan, available_devices):
# TODO: pack multiple models into one GPU
# Currently, we only support one model per GPU
all_grouped_models = []
group_model = {}
assert(len(available_devices) > 0) # There should be at least 1 device, set in CGO_DEVICES
for idx, m in enumerate(logical_plan.models):
# models in one group should
# (1) not use more GPUs than available_devices
# (2) be connected in the logical plan (independent models should be assembled in multiple groups)
# (3) use the same evaluator (Lightning with MultiModelLightningModule and MultiModelTrainer)
if len(group_model) > 0 and \
(AssemblePolicy._check_graph_connectivity(m, group_model, logical_plan) == False or
AssemblePolicy._check_evaluator(m, group_model) == False):
all_grouped_models.append(group_model)
group_model = {}
group_model[m] = available_devices[idx % len(available_devices)]
if len(group_model) == len(available_devices) or \
idx == len(logical_plan.models) - 1:
all_grouped_models.append(group_model)
group_model = {}
return all_grouped_models
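Finally, a minimal end-to-end sketch of wiring this middleware onto an engine, following the test fixture later in this diff; the host name and GPU indices are placeholders, and `models` is assumed to be a list of frozen GraphModelSpace instances with the multi-model evaluator attached.
from nni.experiment.config import RemoteConfig, RemoteMachineConfig
from nni.nas.execution import SequentialExecutionEngine
from nni.nas.execution.cgo import CrossGraphOptimization

remote = RemoteConfig(machine_list=[])
remote.machine_list.append(RemoteMachineConfig(host='gpu-server', gpu_indices=[0, 1]))  # placeholder machine

cgo = CrossGraphOptimization(remote_config=remote, batch_waiting_time=0)
cgo.set_engine(SequentialExecutionEngine(continue_on_failure=True))  # registers the metric / training-end callbacks

cgo.submit_models(*models)   # models arriving within one batch may be merged into a single trial
cgo.wait_models()            # block until all submitted models finish
cgo.shutdown()               # stop the consumer thread and shut down the inner engine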

View file

@@ -1,402 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
__all__ = ['CGOExecutionEngine', 'TrialSubmission']
import logging
import os
import random
import string
import time
import threading
from typing import Iterable, List, Dict, Tuple, cast
from dataclasses import dataclass
from nni.common.device import GPUDevice, Device
from nni.experiment.config.training_services import RemoteConfig
from nni.nas import utils
from nni.nas.execution.common import (
AbstractExecutionEngine, AbstractGraphListener, WorkerInfo,
Model, ModelStatus, MetricData, Node,
RetiariiAdvisor, send_trial, receive_trial_parameters, get_advisor,
)
from nni.nas.execution.pytorch import codegen
from nni.nas.evaluator.pytorch.lightning import Lightning
from nni.nas.evaluator.pytorch.cgo.evaluator import _MultiModelSupervisedLearningModule
from nni.nas.execution.pytorch.graph import BaseGraphData
from .logical_optimizer.logical_plan import LogicalPlan, AbstractLogicalNode
from .logical_optimizer.opt_dedup_input import DedupInputOptimizer
_logger = logging.getLogger(__name__)
def _noop(*args, **kwargs):
pass
@dataclass
class TrialSubmission:
model: Model
placement: Dict[Node, Device]
grouped_models: List[Model]
class CGOExecutionEngine(AbstractExecutionEngine):
"""
The execution engine with Cross-Graph Optimization (CGO).
Only models using PyTorch Lighting and MultiModelSupervisedLearningModule as the evaluator can be optimized.
Otherwise, a model will be submitted independently without any cross-graph optimization.
Parameters
----------
training_service
The remote training service config.
max_concurrency
The maximum number of trials to run concurrently.
batch_waiting_time
Seconds to wait for each batch of trial submission.
The trials within one batch could apply cross-graph optimization.
rest_port
The port of the experiment's rest server
rest_url_prefix
The url prefix of the experiment's rest entry
"""
def __init__(self, training_service: RemoteConfig,
max_concurrency: int = None,
batch_waiting_time: int = 60,
rest_port: int | None = None,
rest_url_prefix: str | None = None
) -> None:
self.port = rest_port
self.url_prefix = rest_url_prefix
self._listeners: List[AbstractGraphListener] = []
self._running_models: Dict[int, Model] = dict()
self.logical_plan_counter = 0
self.available_devices: List[Device] = []
self.max_concurrency: int = max_concurrency
devices = self._construct_devices(training_service)
for device in devices:
self.available_devices.append(device)
self.all_devices = self.available_devices.copy()
self._batch_waiting_time = batch_waiting_time # seconds to wait for all models in a batch to do cross-graph optimization
self._optimizers = [DedupInputOptimizer()]
self._original_models = {}
self._original_model_to_multi_model = {}
self._trial_to_original_models = {}
self._trial_used_devices: Dict[int, List[Device]] = {}
self._history: List[Model] = []
self._queuing_models: List[Model] = []
self._models_to_retry: List[Model] = []
self._queue_lock = threading.Lock()
# register advisor callbacks
advisor: RetiariiAdvisor = get_advisor()
advisor.register_callbacks({
'send_trial': _noop,
'request_trial_jobs': _noop,
'trial_end': self._trial_end_callback,
'intermediate_metric': self._intermediate_metric_callback,
'final_metric': self._final_metric_callback
})
self._stopped = False
self._consumer_thread = threading.Thread(target=self._consume_models)
self._consumer_thread.start()
def _construct_devices(self, training_service):
devices = []
if hasattr(training_service, 'machine_list'):
for machine in cast(RemoteConfig, training_service).machine_list:
assert machine.gpu_indices is not None, \
'gpu_indices must be set in RemoteMachineConfig for CGO execution engine'
assert isinstance(machine.gpu_indices, list), 'gpu_indices must be a list'
for gpu_idx in machine.gpu_indices:
devices.append(GPUDevice(machine.host, gpu_idx))
return devices
def join(self):
self._stopped = True
self._consumer_thread.join()
def add_optimizer(self, opt):
self._optimizers.append(opt)
def submit_models(self, *models: List[Model]) -> None:
curr_time = time.time()
_logger.info('%d models are submitted', len(models))
self._queue_lock.acquire()
self._queuing_models.extend([(curr_time, _) for _ in models])
self._queue_lock.release()
def _submit_retry_models(self, models: List[Model]) -> None:
_logger.info('%d models are retried', len(models))
self._queue_lock.acquire()
self._models_to_retry.extend(models)
self._queue_lock.release()
def _consume_models(self):
# a thread to monitor self._models_to_retry and self._queuing_models to consume them in batch
while not self._stopped:
if len(self._models_to_retry) > 0:
self._queue_lock.acquire()
# retrying jobs should be first scheduled.
for m in self._models_to_retry:
if len(self.available_devices) > 0:
self._submit_models_in_batch(m) # submit the single model to avoid cross-graph optimization.
self._models_to_retry = self._models_to_retry[1:]
self._queue_lock.release()
if len(self._queuing_models) > 0:
self._queue_lock.acquire()
curr_time = time.time()
num_models_to_submit = len(self.available_devices)
if self.max_concurrency:
num_models_to_submit = min(num_models_to_submit, self.max_concurrency)
if curr_time - self._queuing_models[0][0] > self._batch_waiting_time:
num_models_to_submit = min(num_models_to_submit, len(self._queuing_models))
if num_models_to_submit > 0:
self._submit_models_in_batch(*[_[1] for _ in self._queuing_models[:num_models_to_submit]])
self._queuing_models = self._queuing_models[num_models_to_submit:]
self._queue_lock.release()
time.sleep(1)
def _extract_placement_constaint(self, placement_mapping: Dict[Node, Device]):
unique_gpus = sorted(list(set([e for e in placement_mapping.values() if isinstance(e, GPUDevice)])))
placement_constraint = None
if len(unique_gpus) > 0:
placement_constraint = {}
placement_constraint['type'] = 'Device'
placement_constraint['gpus'] = [(e.node_id, e.gpu_id) for e in unique_gpus]
return placement_constraint
def _submit_models_in_batch(self, *models: List[Model]) -> None:
_logger.info('%d models are submitted in batch', len(models))
_logger.debug('model id: %s', str([m.model_id for m in models]))
logical = self._build_logical(models)
for opt in self._optimizers:
opt.convert(logical)
phy_models_and_placements = self._assemble(logical)
for model, placement, grouped_models in phy_models_and_placements:
data = BaseGraphData(codegen.model_to_pytorch_script(model, placement=placement), model.evaluator, {})
placement_constraint = self._extract_placement_constaint(placement)
trial_id = send_trial(data.dump(), placement_constraint=placement_constraint)
# unique non-cpu devices used by the trial
self._trial_used_devices[trial_id] = list(set([_ for _ in placement.values() if isinstance(_, GPUDevice)]))
# currently, it is impossible for search strategy to submit models more than the number of available devices
for used_device in self._trial_used_devices[trial_id]:
self.available_devices.remove(used_device) # used_device must be in self.available_devices
self._running_models[trial_id] = model
self._trial_to_original_models[trial_id] = []
for m in grouped_models:
self._original_models[m.model_id] = m
self._original_model_to_multi_model[m.model_id] = model
self._trial_to_original_models[trial_id].append(m.model_id)
self._history.append(m)
def list_models(self) -> Iterable[Model]:
return self._history
def _assemble(self, logical_plan: LogicalPlan) -> List[Tuple[Model, Dict[Node, Device], List[Model]]]:
"""
Return the assembled models as a list of tuple.
Each tuple contains the assembled model, the device placement of graph nodes, and the original models.
"""
# try to use the available_devices first so that it can be launched as early as possible
# if free devices are not enough to assemble all models in one trial, try all devices
if len(self.available_devices) > 0:
grouped_models: List[Dict[Model, Device]] = AssemblePolicy().group(logical_plan, self.available_devices)
if len(self.available_devices) == 0 or len(grouped_models) > 1:
grouped_models: List[Dict[Model, Device]] = AssemblePolicy().group(logical_plan, self.all_devices)
phy_models_and_placements = []
for multi_model in grouped_models:
model, model_placement = logical_plan.assemble(multi_model)
assert isinstance(model.evaluator, Lightning), \
"cross-graph optimization only supports pytorch lighting as evaluator"
assert isinstance(model.evaluator.module, _MultiModelSupervisedLearningModule), \
"cross-graph optimization only support MultiModelSupervisedLearningModule"
# replace the module with a new instance whose n_models is set
# n_models must be set in __init__, otherwise it cannot be captured by serialize_cls
new_module_init_params = model.evaluator.module.dump_kwargs().copy()
# MultiModelSupervisedLearningModule hides n_models of _MultiModelSupervisedLearningModule from users
new_module_init_params['n_models'] = len(multi_model)
new_module = _MultiModelSupervisedLearningModule(**new_module_init_params)
model.evaluator.module = new_module
phy_models_and_placements.append((model, model_placement, multi_model.keys()))
return phy_models_and_placements
def _build_logical(self, models: List[Model]) -> LogicalPlan:
logical_plan = LogicalPlan(plan_id=self.logical_plan_counter)
for model in models:
logical_plan.add_model(model)
self.logical_plan_counter += 1
return logical_plan
def register_graph_listener(self, listener: AbstractGraphListener) -> None:
self._listeners.append(listener)
# def _send_trial_callback(self, paramater: dict) -> None:
# if len(self.available_devices) == 0:
# _logger.warning('There is no available devices, but trial is submitted.')
# _logger.debug('Resource used. Remaining: %d', len(self.available_devices))
# def _request_trial_jobs_callback(self, num_trials: int) -> None:
# self.resources += num_trials
# _logger.info('on_resource_available: %d', self.resources)
def _trial_end_callback(self, trial_id: int, success: bool) -> None:
model = self._running_models[trial_id]
if success:
model.status = ModelStatus.Trained
else:
model.status = ModelStatus.Failed
models_to_retry = []
for model_id in self._original_model_to_multi_model:
if self._original_model_to_multi_model[model_id] == model:
original_model = self._original_models[model_id]
if success:
original_model.status = ModelStatus.Trained
else:
original_model.status = ModelStatus.Failed
# the failed models in a multi-model will be retried one by one w/o CGO
if len(self._trial_to_original_models[trial_id]) > 1:
models_to_retry.append(original_model)
for listener in self._listeners:
listener.on_training_end(original_model, success)
if len(models_to_retry) > 0:
self._submit_retry_models(models_to_retry)
self.available_devices.extend(self._trial_used_devices[trial_id])
self.available_devices = sorted(list(set(self.available_devices)))
del self._running_models[trial_id]
def _intermediate_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
merged_metrics = {}
for idx, _ in enumerate(metrics):
merged_metrics[self._trial_to_original_models[trial_id][idx]] = metrics[idx]
for model_id in merged_metrics:
self._original_models[model_id].intermediate_metrics.append(merged_metrics[model_id])
for listener in self._listeners:
listener.on_intermediate_metric(self._original_models[model_id], merged_metrics[model_id])
def _final_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
_logger.debug(metrics)
if isinstance(metrics, float):
self._listeners[0].on_metric(self._running_models[trial_id], metrics)
else:
merged_metrics = {}
for idx, _ in enumerate(metrics):
merged_metrics[self._trial_to_original_models[trial_id][idx]] = metrics[idx]
for model_id in merged_metrics:
self._original_models[model_id].metric = merged_metrics[model_id]
for listener in self._listeners:
listener.on_metric(self._original_models[model_id], merged_metrics[model_id])
def query_available_resource(self) -> List[WorkerInfo]:
# the _queuing_models need to use available_devices first
self._queue_lock.acquire()
available_for_more_models = len(self.available_devices) - len(self._queuing_models) - len(self._models_to_retry)
self._queue_lock.release()
return available_for_more_models
def budget_exhausted(self) -> bool:
advisor = get_advisor()
return advisor.stopping
@classmethod
def trial_execute_graph(cls) -> None:
"""
Initialize the model, hand it over to trainer.
"""
graph_data = BaseGraphData.load(receive_trial_parameters())
_logger.info('CGO_ENGINE trial parameters received')
random_str = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(6))
file_name = f'_generated_model/{random_str}.py'
os.makedirs(os.path.dirname(file_name), exist_ok=True)
with open(file_name, 'w') as f:
f.write(graph_data.model_script)
trainer_instance = graph_data.evaluator
model_cls = utils.import_(f'_generated_model.{random_str}._model')
trainer_instance.fit(model_cls())
os.remove(file_name)
class AssemblePolicy:
@staticmethod
def _is_related_node(model: Model, node: Node):
if isinstance(node, AbstractLogicalNode):
if model in node.related_models:
return True
else:
if model == node.graph.model:
return True
return False
@staticmethod
def _check_graph_connectivity(model: Model,
group_model: Dict[Model, Device],
logical_plan: LogicalPlan) -> bool:
for edge in logical_plan.logical_graph.edges:
if AssemblePolicy._is_related_node(model, edge.head) or \
AssemblePolicy._is_related_node(model, edge.tail):
for grouped_model in group_model:
if AssemblePolicy._is_related_node(grouped_model, edge.head) or \
AssemblePolicy._is_related_node(grouped_model, edge.tail):
return True
return False
@staticmethod
def _check_evaluator(new_model: Model, group_model: Dict[Model, Device]) -> bool:
if not (isinstance(new_model.evaluator, Lightning)
and isinstance(new_model.evaluator.module, _MultiModelSupervisedLearningModule)):
return False
for m in group_model:
if not m.evaluator == new_model.evaluator:
return False
return True
@staticmethod
def group(logical_plan, available_devices):
# TODO: Packing multiple model in one GPU
# Currently, we only support one model per GPU
all_grouped_models = []
group_model = {}
assert(len(available_devices) > 0) # There should be at least 1 device, set in CGO_DEVICES
for idx, m in enumerate(logical_plan.models):
# models in one group should
# (1) not use more GPUs than available_devices
# (2) be connected in the logical plan (independent models should be assembled in multiple groups)
# (3) use same MultiModelSupervisedLearningModule
if len(group_model) > 0 and \
(AssemblePolicy._check_graph_connectivity(m, group_model, logical_plan) == False or
AssemblePolicy._check_evaluator(m, group_model) == False):
all_grouped_models.append(group_model)
group_model = {}
group_model[m] = available_devices[idx % len(available_devices)]
if len(group_model) == len(available_devices) or \
idx == len(logical_plan.models) - 1:
all_grouped_models.append(group_model)
group_model = {}
return all_grouped_models

View file

View file

@@ -0,0 +1,41 @@
{
"_model": {
"inputs": ["image"],
"outputs": ["metric"],
"nodes": {
"stem": {"operation": {"type": "_cell", "cell_name": "stem"}},
"flatten": {"operation": {"type": "__torch__.torch.nn.Flatten"}},
"fc1": {"operation": {"type": "__torch__.torch.nn.Linear", "parameters": {"out_features": 256, "in_features": 1024}}},
"fc2": {"operation": {"type": "__torch__.torch.nn.Linear", "parameters": {"out_features": 10, "in_features": 256}}},
"softmax": {"operation": {"type": "__torch__.torch.nn.Softmax"}}
},
"edges": [
{"head": ["_inputs", 0], "tail": ["stem", null]},
{"head": ["stem", null], "tail": ["flatten", null]},
{"head": ["flatten", null], "tail": ["fc1", null]},
{"head": ["fc1", null], "tail": ["fc2", null]},
{"head": ["fc2", null], "tail": ["softmax", null]},
{"head": ["softmax", null], "tail": ["_outputs", 0]}
]
},
"stem": {
"nodes": {
"conv1": {"operation": {"type": "__torch__.torch.nn.Conv2d", "parameters": {"out_channels": 32, "in_channels": 1, "kernel_size": 5}}},
"pool1": {"operation": {"type": "__torch__.torch.nn.MaxPool2d", "parameters": {"kernel_size": 2}}},
"conv2": {"operation": {"type": "__torch__.torch.nn.Conv2d", "parameters": {"out_channels": 64, "in_channels": 32, "kernel_size": 5}}},
"pool2": {"operation": {"type": "__torch__.torch.nn.MaxPool2d", "parameters": {"kernel_size": 2}}}
},
"edges": [
{"head": ["_inputs", 0], "tail": ["conv1", null]},
{"head": ["conv1", null], "tail": ["pool1", null]},
{"head": ["pool1", null], "tail": ["conv2", null]},
{"head": ["conv2", null], "tail": ["pool2", null]},
{"head": ["pool2", null], "tail": ["_outputs", 0]}
]
}
}
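This JSON is the graph IR consumed by the test below; a short sketch of how it is turned back into a model space, mirroring the test's `_load_mnist` helper (the file path is relative to the test directory):
import nni
from nni.nas.space.pytorch import PytorchGraphModelSpace

with open('mnist_pytorch.json') as f:
    mnist_model = PytorchGraphModelSpace._load(**nni.load(fp=f))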

View file

@@ -0,0 +1,270 @@
from pathlib import Path
import pytest
import torch
import torch.nn as nn
import torchmetrics
from torchvision.datasets import MNIST
from torchvision import transforms
from pytorch_lightning.utilities.seed import seed_everything
import nni
from nni.experiment.config import RemoteConfig, RemoteMachineConfig
from nni.nas.evaluator.pytorch import Lightning, DataLoader
from nni.nas.execution import SequentialExecutionEngine
from nni.nas.execution.cgo import CrossGraphOptimization
from nni.nas.execution.cgo.evaluator import MultiModelLightningModule, MultiModelTrainer
from nni.nas.execution.cgo.logical_optimizer.logical_plan import LogicalPlan
from nni.nas.execution.cgo.logical_optimizer.opt_dedup_input import DedupInputOptimizer
from nni.nas.space import Node, ModelStatus
from nni.nas.space.pytorch import PytorchGraphModelSpace
from nni.runtime.trial_command_channel import get_default_trial_command_channel, set_default_trial_command_channel
from ut.sdk.helper.trial_command_channel import TestHelperTrialCommandChannel
class _model_cpu(nn.Module):
def __init__(self):
super().__init__()
self.M_1_stem = M_1_stem()
self.M_2_stem = M_2_stem()
self.M_1_flatten = torch.nn.Flatten()
self.M_2_flatten = torch.nn.Flatten()
self.M_1_fc1 = torch.nn.Linear(out_features=256, in_features=1024)
self.M_2_fc1 = torch.nn.Linear(out_features=256, in_features=1024)
self.M_1_fc2 = torch.nn.Linear(out_features=10, in_features=256)
self.M_2_fc2 = torch.nn.Linear(out_features=10, in_features=256)
self.M_1_softmax = torch.nn.Softmax()
self.M_2_softmax = torch.nn.Softmax()
def forward(self, *_inputs):
M_1__inputs_to_M_2_stem = _inputs[0]
M_1_stem = self.M_1_stem(_inputs[0])
M_2_stem = self.M_2_stem(M_1__inputs_to_M_2_stem)
M_1_flatten = self.M_1_flatten(M_1_stem)
M_2_flatten = self.M_2_flatten(M_2_stem)
M_1_fc1 = self.M_1_fc1(M_1_flatten)
M_2_fc1 = self.M_2_fc1(M_2_flatten)
M_1_fc2 = self.M_1_fc2(M_1_fc1)
M_2_fc2 = self.M_2_fc2(M_2_fc1)
M_1_softmax = self.M_1_softmax(M_1_fc2)
M_2_softmax = self.M_2_softmax(M_2_fc2)
return M_1_softmax, M_2_softmax
class _model_gpu(nn.Module):
def __init__(self):
super().__init__()
self.M_1_stem = M_1_stem().to('cuda:0')
self.M_2_stem = M_2_stem().to('cuda:1')
self.M_1_flatten = torch.nn.Flatten().to('cuda:0')
self.M_2_flatten = torch.nn.Flatten().to('cuda:1')
self.M_1_fc1 = torch.nn.Linear(out_features=256, in_features=1024).to('cuda:0')
self.M_2_fc1 = torch.nn.Linear(out_features=256, in_features=1024).to('cuda:1')
self.M_1_fc2 = torch.nn.Linear(out_features=10, in_features=256).to('cuda:0')
self.M_2_fc2 = torch.nn.Linear(out_features=10, in_features=256).to('cuda:1')
self.M_1_softmax = torch.nn.Softmax().to('cuda:0')
self.M_2_softmax = torch.nn.Softmax().to('cuda:1')
def forward(self, *_inputs):
M_1__inputs_to_M_1_stem = _inputs[0].to("cuda:0")
M_1__inputs_to_M_2_stem = _inputs[0].to("cuda:1")
M_1_stem = self.M_1_stem(M_1__inputs_to_M_1_stem)
M_2_stem = self.M_2_stem(M_1__inputs_to_M_2_stem)
M_1_flatten = self.M_1_flatten(M_1_stem)
M_2_flatten = self.M_2_flatten(M_2_stem)
M_1_fc1 = self.M_1_fc1(M_1_flatten)
M_2_fc1 = self.M_2_fc1(M_2_flatten)
M_1_fc2 = self.M_1_fc2(M_1_fc1)
M_2_fc2 = self.M_2_fc2(M_2_fc1)
M_1_softmax = self.M_1_softmax(M_1_fc2)
M_2_softmax = self.M_2_softmax(M_2_fc2)
return M_1_softmax, M_2_softmax
class M_1_stem(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(out_channels=32, in_channels=1, kernel_size=5)
self.pool1 = torch.nn.MaxPool2d(kernel_size=2)
self.conv2 = torch.nn.Conv2d(out_channels=64, in_channels=32, kernel_size=5)
self.pool2 = torch.nn.MaxPool2d(kernel_size=2)
def forward(self, *_inputs):
conv1 = self.conv1(_inputs[0])
pool1 = self.pool1(conv1)
conv2 = self.conv2(pool1)
pool2 = self.pool2(conv2)
return pool2
class M_2_stem(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(out_channels=32, in_channels=1, kernel_size=5)
self.pool1 = torch.nn.MaxPool2d(kernel_size=2)
self.conv2 = torch.nn.Conv2d(out_channels=64, in_channels=32, kernel_size=5)
self.pool2 = torch.nn.MaxPool2d(kernel_size=2)
def forward(self, *_inputs):
conv1 = self.conv1(_inputs[0])
pool1 = self.pool1(conv1)
conv2 = self.conv2(pool1)
pool2 = self.pool2(conv2)
return pool2
def create_evaluator(n_models=None):
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = nni.trace(MNIST)(root='data/mnist', train=True, download=False, transform=transform)
test_dataset = nni.trace(MNIST)(root='data/mnist', train=False, download=False, transform=transform)
multi_module = MultiModelLightningModule(
nn.CrossEntropyLoss(),
torchmetrics.Accuracy('multiclass', num_classes=10),
n_models=n_models
)
lightning = Lightning(
multi_module,
MultiModelTrainer(max_epochs=1, limit_train_batches=0.25, enable_progress_bar=True),
train_dataloaders=DataLoader(train_dataset, batch_size=100),
val_dataloaders=DataLoader(test_dataset, batch_size=100)
)
return lightning
def _load_mnist(n_models: int = 1):
path = Path(__file__).parent / 'mnist_pytorch.json'
with open(path) as f:
mnist_model = PytorchGraphModelSpace._load(**nni.load(fp=f))
mnist_model.evaluator = create_evaluator()
mnist_model.status = ModelStatus.Frozen
if n_models == 1:
return mnist_model
else:
models = [mnist_model]
for _ in range(n_models - 1):
forked_model = mnist_model.fork()
forked_model.status = ModelStatus.Frozen
models.append(forked_model)
return models
def _build_logical_with_mnist(n_models: int):
lp = LogicalPlan(model_cls=PytorchGraphModelSpace)
models = _load_mnist(n_models=n_models)
for m in models:
lp.add_model(m)
return lp, models
@pytest.fixture(autouse=True)
def seed():
seed_everything(42)
@pytest.fixture
def trial_command_channel():
_default_channel = get_default_trial_command_channel()
channel = TestHelperTrialCommandChannel()
set_default_trial_command_channel(channel)
nni.get_next_parameter()
yield channel
set_default_trial_command_channel(_default_channel)
@pytest.fixture(params=[1, 2, 4])
def cgo(request):
remote = RemoteConfig(machine_list=[])
remote.machine_list.append(RemoteMachineConfig(host='test', gpu_indices=list(range(request.param))))
cgo = CrossGraphOptimization(remote_config=remote, batch_waiting_time=0)
yield cgo
cgo.shutdown()
def test_multi_model_trainer_cpu(trial_command_channel):
evaluator = create_evaluator(n_models=2)
evaluator.evaluate(_model_cpu())
result = trial_command_channel.final
assert len(result) == 2
for _ in result:
assert _ > 0.8
@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='test requires GPU and torch+cuda')
def test_multi_model_trainer_gpu(trial_command_channel):
evaluator = create_evaluator(n_models=2)
evaluator.evaluate(_model_gpu())
result = trial_command_channel.final
assert len(result) == 2
for _ in result:
assert _ > 0.8
def test_add_model():
lp, models = _build_logical_with_mnist(3)
for node in lp.logical_graph.hidden_nodes:
old_nodes = [m.root_graph.get_node_by_id(node.id) for m in models]
assert any([old_nodes[0].__repr__() == Node.__repr__(x) for x in old_nodes])
def test_dedup_input(cgo):
lp, _ = _build_logical_with_mnist(3)
opt = DedupInputOptimizer()
opt.convert(lp)
phy_models = cgo._assemble(lp)
if len(cgo.available_devices) == 4:
assert len(list(phy_models)) == 1
elif len(cgo.available_devices) == 2:
assert len(list(phy_models)) == 2
elif len(cgo.available_devices) == 1:
assert len(list(phy_models)) == 3
else:
raise ValueError(f'Invalid device count: {cgo.available_devices}')
cgo.shutdown()
def test_submit_models(cgo):
import logging
logging.getLogger('nni.nas.execution.sequential').setLevel(logging.DEBUG)
models = _load_mnist(2)
engine = SequentialExecutionEngine(continue_on_failure=True)
cgo.set_engine(engine)
cgo.submit_models(*models)
cgo.wait_models()
if not torch.cuda.is_available():
for model in models: # can't be trained without gpu.
assert model.status == ModelStatus.Failed
if len(cgo.available_devices) == 1:
assert engine._model_count == 2 # 2 single
else:
assert engine._model_count == 3 # 1 + retry 2
elif torch.cuda.device_count() == 1 and len(cgo.available_devices) == 1:
# Should be the case on pipeline.
assert engine._model_count == 2 # No merge at all.
for model in models:
assert model.status == ModelStatus.Trained
assert model.metrics.final > 0.8

View file

@@ -1,366 +0,0 @@
import os
import threading
import unittest
import time
import torch
import torch.nn as nn
from pytorch_lightning.utilities.seed import seed_everything
from pathlib import Path
import nni
from nni.experiment.config import RemoteConfig, RemoteMachineConfig
from nni.runtime.tuner_command_channel import legacy as protocol
import json
try:
from nni.common.device import GPUDevice
from nni.retiarii.execution.cgo_engine import CGOExecutionEngine
from nni.retiarii import Model
from nni.retiarii.graph import Node
from nni.retiarii import Model, submit_models
from nni.retiarii.integration import RetiariiAdvisor
from nni.retiarii.execution import set_execution_engine
from nni.retiarii.execution.logical_optimizer.opt_dedup_input import DedupInputOptimizer
from nni.retiarii.execution.logical_optimizer.logical_plan import LogicalPlan
from nni.retiarii.utils import import_
from nni.retiarii import serialize
import nni.retiarii.evaluator.pytorch.lightning as pl
from nni.retiarii.evaluator.pytorch.cgo.evaluator import MultiModelSupervisedLearningModule, _MultiModelSupervisedLearningModule
import nni.retiarii.evaluator.pytorch.cgo.trainer as cgo_trainer
import nni.retiarii.integration_api
module_import_failed = False
except ImportError:
module_import_failed = True
import pytest
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import Dataset
from sklearn.datasets import load_diabetes
pytestmark = pytest.mark.skip(reason='Will be rewritten.')
class _model_cpu(nn.Module):
def __init__(self):
super().__init__()
self.M_1_stem = M_1_stem()
self.M_2_stem = M_2_stem()
self.M_1_flatten = torch.nn.Flatten()
self.M_2_flatten = torch.nn.Flatten()
self.M_1_fc1 = torch.nn.Linear(out_features=256, in_features=1024)
self.M_2_fc1 = torch.nn.Linear(out_features=256, in_features=1024)
self.M_1_fc2 = torch.nn.Linear(out_features=10, in_features=256)
self.M_2_fc2 = torch.nn.Linear(out_features=10, in_features=256)
self.M_1_softmax = torch.nn.Softmax()
self.M_2_softmax = torch.nn.Softmax()
def forward(self, *_inputs):
M_1__inputs_to_M_2_stem = _inputs[0]
M_1_stem = self.M_1_stem(_inputs[0])
M_2_stem = self.M_2_stem(M_1__inputs_to_M_2_stem)
M_1_flatten = self.M_1_flatten(M_1_stem)
M_2_flatten = self.M_2_flatten(M_2_stem)
M_1_fc1 = self.M_1_fc1(M_1_flatten)
M_2_fc1 = self.M_2_fc1(M_2_flatten)
M_1_fc2 = self.M_1_fc2(M_1_fc1)
M_2_fc2 = self.M_2_fc2(M_2_fc1)
M_1_softmax = self.M_1_softmax(M_1_fc2)
M_2_softmax = self.M_2_softmax(M_2_fc2)
return M_1_softmax, M_2_softmax
class _model_gpu(nn.Module):
def __init__(self):
super().__init__()
self.M_1_stem = M_1_stem().to('cuda:0')
self.M_2_stem = M_2_stem().to('cuda:1')
self.M_1_flatten = torch.nn.Flatten().to('cuda:0')
self.M_2_flatten = torch.nn.Flatten().to('cuda:1')
self.M_1_fc1 = torch.nn.Linear(out_features=256, in_features=1024).to('cuda:0')
self.M_2_fc1 = torch.nn.Linear(out_features=256, in_features=1024).to('cuda:1')
self.M_1_fc2 = torch.nn.Linear(out_features=10, in_features=256).to('cuda:0')
self.M_2_fc2 = torch.nn.Linear(out_features=10, in_features=256).to('cuda:1')
self.M_1_softmax = torch.nn.Softmax().to('cuda:0')
self.M_2_softmax = torch.nn.Softmax().to('cuda:1')
def forward(self, *_inputs):
M_1__inputs_to_M_1_stem = _inputs[0].to("cuda:0")
M_1__inputs_to_M_2_stem = _inputs[0].to("cuda:1")
M_1_stem = self.M_1_stem(M_1__inputs_to_M_1_stem)
M_2_stem = self.M_2_stem(M_1__inputs_to_M_2_stem)
M_1_flatten = self.M_1_flatten(M_1_stem)
M_2_flatten = self.M_2_flatten(M_2_stem)
M_1_fc1 = self.M_1_fc1(M_1_flatten)
M_2_fc1 = self.M_2_fc1(M_2_flatten)
M_1_fc2 = self.M_1_fc2(M_1_fc1)
M_2_fc2 = self.M_2_fc2(M_2_fc1)
M_1_softmax = self.M_1_softmax(M_1_fc2)
M_2_softmax = self.M_2_softmax(M_2_fc2)
return M_1_softmax, M_2_softmax
class M_1_stem(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(out_channels=32, in_channels=1, kernel_size=5)
self.pool1 = torch.nn.MaxPool2d(kernel_size=2)
self.conv2 = torch.nn.Conv2d(out_channels=64, in_channels=32, kernel_size=5)
self.pool2 = torch.nn.MaxPool2d(kernel_size=2)
def forward(self, *_inputs):
conv1 = self.conv1(_inputs[0])
pool1 = self.pool1(conv1)
conv2 = self.conv2(pool1)
pool2 = self.pool2(conv2)
return pool2
class M_2_stem(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(out_channels=32, in_channels=1, kernel_size=5)
self.pool1 = torch.nn.MaxPool2d(kernel_size=2)
self.conv2 = torch.nn.Conv2d(out_channels=64, in_channels=32, kernel_size=5)
self.pool2 = torch.nn.MaxPool2d(kernel_size=2)
def forward(self, *_inputs):
conv1 = self.conv1(_inputs[0])
pool1 = self.pool1(conv1)
conv2 = self.conv2(pool1)
pool2 = self.pool2(conv2)
return pool2
def _reset():
# this is to not affect other tests in sdk
nni.trial._intermediate_seq = 0
nni.trial._params = {'foo': 'bar', 'parameter_id': 0, 'parameters': {}}
nni.runtime.platform.test._last_metric = None
nni.retiarii.integration_api._advisor = None
nni.retiarii.execution.api._execution_engine = None
seed_everything(42)
def _new_trainer():
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = serialize(MNIST, root='data/mnist', train=True, download=True, transform=transform)
test_dataset = serialize(MNIST, root='data/mnist', train=False, download=True, transform=transform)
multi_module = _MultiModelSupervisedLearningModule(nn.CrossEntropyLoss, {'acc': pl.AccuracyWithLogits})
lightning = pl.Lightning(multi_module, cgo_trainer.Trainer(use_cgo=True,
max_epochs=1,
limit_train_batches=0.25,
enable_progress_bar=False),
train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
val_dataloaders=pl.DataLoader(test_dataset, batch_size=100))
return lightning
def _load_mnist(n_models: int = 1):
path = Path('ut/nas/mnist_pytorch.json')
with open(path) as f:
mnist_model = Model._load(nni.load(fp=f))
mnist_model.evaluator = _new_trainer()
if n_models == 1:
return mnist_model
else:
models = [mnist_model]
for i in range(n_models - 1):
forked_model = mnist_model.fork()
forked_model.evaluator = _new_trainer()
models.append(forked_model)
return models
def _get_final_result():
result = nni.load(nni.runtime.platform.test._last_metric)['value']
if isinstance(result, list):
return [float(_) for _ in result]
else:
if isinstance(result, str) and '[' in result:
return nni.load(result)
return [float(result)]
class CGOEngineTest(unittest.TestCase):
def setUp(self):
if module_import_failed:
self.skipTest('test skip due to failed import of nni.retiarii.evaluator.pytorch.lightning')
def test_multi_model_trainer_cpu(self):
_reset()
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = serialize(MNIST, root='data/mnist', train=True, download=True, transform=transform)
test_dataset = serialize(MNIST, root='data/mnist', train=False, download=True, transform=transform)
multi_module = _MultiModelSupervisedLearningModule(nn.CrossEntropyLoss, {'acc': pl.AccuracyWithLogits}, n_models=2)
lightning = pl.Lightning(multi_module, cgo_trainer.Trainer(use_cgo=True,
max_epochs=1,
limit_train_batches=0.25),
train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
val_dataloaders=pl.DataLoader(test_dataset, batch_size=100))
lightning._execute(_model_cpu)
result = _get_final_result()
assert len(result) == 2
for _ in result:
assert _ > 0.8
def test_multi_model_trainer_gpu(self):
_reset()
if not (torch.cuda.is_available() and torch.cuda.device_count() >= 2):
pytest.skip('test requires GPU and torch+cuda')
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = serialize(MNIST, root='data/mnist', train=True, download=True, transform=transform)
test_dataset = serialize(MNIST, root='data/mnist', train=False, download=True, transform=transform)
multi_module = _MultiModelSupervisedLearningModule(nn.CrossEntropyLoss, {'acc': pl.AccuracyWithLogits}, n_models=2)
lightning = pl.Lightning(multi_module, cgo_trainer.Trainer(use_cgo=True,
max_epochs=1,
limit_train_batches=0.25),
train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
val_dataloaders=pl.DataLoader(test_dataset, batch_size=100))
lightning._execute(_model_gpu)
result = _get_final_result()
assert len(result) == 2
for _ in result:
assert _ > 0.8
def _build_logical_with_mnist(self, n_models: int):
lp = LogicalPlan()
models = _load_mnist(n_models=n_models)
for m in models:
lp.add_model(m)
return lp, models
def test_add_model(self):
_reset()
lp, models = self._build_logical_with_mnist(3)
for node in lp.logical_graph.hidden_nodes:
old_nodes = [m.root_graph.get_node_by_id(node.id) for m in models]
self.assertTrue(any([old_nodes[0].__repr__() == Node.__repr__(x) for x in old_nodes]))
def test_dedup_input_four_devices(self):
_reset()
lp, models = self._build_logical_with_mnist(3)
opt = DedupInputOptimizer()
opt.convert(lp)
advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
advisor._channel = protocol.LegacyCommandChannel()
advisor.default_worker.start()
advisor.assessor_worker.start()
remote = RemoteConfig(machine_list=[])
remote.machine_list.append(RemoteMachineConfig(host='test', gpu_indices=[0,1,2,3]))
cgo = CGOExecutionEngine(training_service=remote, batch_waiting_time=0)
phy_models = cgo._assemble(lp)
self.assertTrue(len(phy_models) == 1)
advisor.stopping = True
advisor.default_worker.join()
advisor.assessor_worker.join()
cgo.join()
def test_dedup_input_two_devices(self):
_reset()
lp, models = self._build_logical_with_mnist(3)
opt = DedupInputOptimizer()
opt.convert(lp)
advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
advisor._channel = protocol.LegacyCommandChannel()
advisor.default_worker.start()
advisor.assessor_worker.start()
remote = RemoteConfig(machine_list=[])
remote.machine_list.append(RemoteMachineConfig(host='test', gpu_indices=[0,1]))
cgo = CGOExecutionEngine(training_service=remote, batch_waiting_time=0)
phy_models = cgo._assemble(lp)
self.assertTrue(len(phy_models) == 2)
advisor.stopping = True
advisor.default_worker.join()
advisor.assessor_worker.join()
cgo.join()
def test_submit_models(self):
_reset()
os.makedirs('generated', exist_ok=True)
import nni.runtime.platform.test as tt
protocol._set_out_file(open('generated/debug_protocol_out_file.py', 'wb'))
protocol._set_in_file(open('generated/debug_protocol_out_file.py', 'rb'))
models = _load_mnist(2)
advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
advisor._channel = protocol.LegacyCommandChannel()
advisor.default_worker.start()
advisor.assessor_worker.start()
# this is because RetiariiAdvisor only works after `_advisor_initialized` becomes True.
# normally it becomes true when `handle_request_trial_jobs` is invoked
advisor._advisor_initialized = True
remote = RemoteConfig(machine_list=[])
remote.machine_list.append(RemoteMachineConfig(host='test', gpu_indices=[0,1,2,3]))
cgo_engine = CGOExecutionEngine(training_service=remote, batch_waiting_time=0)
set_execution_engine(cgo_engine)
submit_models(*models)
time.sleep(3)
if torch.cuda.is_available() and torch.cuda.device_count() >= 2:
cmd, data = protocol.receive()
params = nni.load(data)
tt.init_params(params)
trial_thread = threading.Thread(target=CGOExecutionEngine.trial_execute_graph)
trial_thread.start()
last_metric = None
while True:
time.sleep(1)
if tt._last_metric:
metric = tt.get_last_metric()
if metric == last_metric:
continue
if 'value' in metric:
metric['value'] = json.dumps(metric['value'])
advisor.handle_report_metric_data(metric)
last_metric = metric
if not trial_thread.is_alive():
trial_thread.join()
break
trial_thread.join()
advisor.stopping = True
advisor.default_worker.join()
advisor.assessor_worker.join()
cgo_engine.join()
if __name__ == '__main__':
unittest.main()