Mirror of https://github.com/microsoft/nni.git
NAS execution engine (stage 3) - CGO (#5360)
This commit is contained in:
Parent
1e4f3f08d3
Commit
67a61e1d94
@@ -1,4 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .engine import *
from .middleware import CrossGraphOptimization
@@ -0,0 +1,148 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

__all__ = ['MultiModelLightningModule', 'MultiModelTrainer']

import pytorch_lightning
import torch
from pytorch_lightning.strategies import SingleDeviceStrategy
from torch import nn
from torchmetrics import Metric

import nni
from nni.nas.evaluator.pytorch.lightning import LightningModule


class MultiModelLightningModule(LightningModule):
    """The lightning module for a merged "multi-model".

    The output of the multi-model is expected to be a tuple of tensors.
    The tensors will each be passed to a criterion and a metric.
    The losses will be added up for backpropagation, and the metrics will be logged.

    The reported metric will be a list of metrics, one for each model.

    Parameters
    ----------
    criterion
        Loss function.
    metric
        Metric function.
    n_models
        Number of models in the multi-model.
    """

    def __init__(self, criterion: nn.Module, metric: Metric, n_models: int | None = None):
        super().__init__()
        self.criterion = criterion
        self.metric = metric
        self.n_models = n_models

    def _dump(self) -> dict:
        return {
            'criterion': self.criterion,
            'metric': self.metric,
            'n_models': self.n_models,
        }

    @staticmethod
    def _load(criterion: nn.Module, metric: Metric, n_models: int | None = None) -> MultiModelLightningModule:
        return MultiModelLightningModule(criterion, metric, n_models)

    def forward(self, x):
        y_hat = self.model(x)
        return y_hat

    def training_step(self, batch, batch_idx):
        x, y = batch
        multi_y_hat = self(x)
        if isinstance(multi_y_hat, tuple):
            assert len(multi_y_hat) == self.n_models
        else:
            assert self.n_models == 1
            multi_y_hat = [multi_y_hat]
        multi_loss = []
        for idx, y_hat in enumerate(multi_y_hat):
            loss = self.criterion(y_hat.to("cpu"), y.to("cpu"))
            self.log(f'train_loss_{idx}', loss, prog_bar=True)
            self.log(f'train_metric_{idx}', self.metric(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)
            multi_loss.append(loss)
        return sum(multi_loss)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        multi_y_hat = self(x)
        if isinstance(multi_y_hat, tuple):
            assert len(multi_y_hat) == self.n_models
        else:
            assert self.n_models == 1
            multi_y_hat = [multi_y_hat]
        for idx, y_hat in enumerate(multi_y_hat):
            self.log(f'val_loss_{idx}', self.criterion(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)
            self.log(f'val_metric_{idx}', self.metric(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        multi_y_hat = self(x)
        if isinstance(multi_y_hat, tuple):
            assert len(multi_y_hat) == self.n_models
        else:
            assert self.n_models == 1
            multi_y_hat = [multi_y_hat]
        for idx, y_hat in enumerate(multi_y_hat):
            self.log(f'test_loss_{idx}', self.criterion(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)
            self.log(f'test_metric_{idx}', self.metric(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def on_validation_epoch_end(self):
        nni.report_intermediate_result(self._get_validation_metrics())

    def teardown(self, stage):
        if stage == 'fit':
            nni.report_final_result(self._get_validation_metrics())

    def _get_validation_metrics(self):
        # TODO: split metric of multiple models?
        assert self.n_models is not None
        return [self.trainer.callback_metrics[f'val_metric_{idx}'].item() for idx in range(self.n_models)]


class _BypassStrategy(SingleDeviceStrategy):
    strategy_name = "single_device"

    def model_to_device(self) -> None:
        pass


@nni.trace
class MultiModelTrainer(pytorch_lightning.Trainer):
    """
    Trainer for cross-graph optimization.

    Parameters
    ----------
    use_cgo
        Whether cross-graph optimization (CGO) is used.
        If it is True, CGO will manage device placement.
        Any device placement from PyTorch Lightning will be bypassed.
        Default: ``True``.
    trainer_kwargs
        Optional keyword arguments passed to trainer. See
        `Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
    """

    def __init__(self, use_cgo: bool = True, **trainer_kwargs):
        if use_cgo:
            if "accelerator" in trainer_kwargs:
                raise ValueError("accelerator should not be set when cross-graph optimization is enabled.")

            if 'strategy' in trainer_kwargs:
                raise ValueError("MultiModelTrainer does not support specifying strategy.")

            trainer_kwargs['strategy'] = _BypassStrategy()

        super().__init__(**trainer_kwargs)
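For illustration, a minimal usage sketch (not part of the diff) of how the two classes above are combined into a Lightning evaluator. It mirrors the unit test added later in this commit; the MNIST dataset and the hyperparameters are placeholders:

    import torch.nn as nn
    import torchmetrics
    from torchvision import transforms
    from torchvision.datasets import MNIST

    import nni
    from nni.nas.evaluator.pytorch import Lightning, DataLoader
    from nni.nas.execution.cgo.evaluator import MultiModelLightningModule, MultiModelTrainer

    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_dataset = nni.trace(MNIST)(root='data/mnist', train=True, download=True, transform=transform)
    test_dataset = nni.trace(MNIST)(root='data/mnist', train=False, download=True, transform=transform)

    # One criterion/metric pair is applied to every output of the merged model.
    # n_models is left as None here; the CGO middleware fills it in when models are grouped.
    module = MultiModelLightningModule(
        nn.CrossEntropyLoss(),
        torchmetrics.Accuracy('multiclass', num_classes=10))

    # use_cgo=True (the default) installs the bypass strategy, so CGO controls device placement.
    trainer = MultiModelTrainer(max_epochs=1)

    evaluator = Lightning(module, trainer,
                          train_dataloaders=DataLoader(train_dataset, batch_size=100),
                          val_dataloaders=DataLoader(test_dataset, batch_size=100))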
@@ -2,13 +2,13 @@
# Licensed under the MIT license.

import copy
from typing import Dict, Tuple, Any
from typing import Dict, Tuple, Any, Type

from nni.retiarii.utils import uid
from nni.common.device import Device, CPUDevice
from nni.mutable.utils import uid

from nni.nas.execution.common.graph import Cell, Edge, Graph, Model, Node
from nni.nas.execution.common.graph_op import Operation, _IOPseudoOperation
from nni.nas.space import Edge, Graph, GraphModelSpace, Node
from nni.nas.space.graph_op import Cell, Operation, _IOPseudoOperation


class AbstractLogicalNode(Node):

@@ -16,7 +16,7 @@ class AbstractLogicalNode(Node):
        super().__init__(graph, node_id, name, operation, _internal=_internal)
        self.related_models = []

    def assemble(self, multi_model_placement: Dict[Model, Device]) -> Tuple[Node, Device]:
    def assemble(self, multi_model_placement: Dict[GraphModelSpace, Device]) -> Tuple[Node, Device]:
        """
        Given a set of models to be formed in a physical model and their device placement,
        this function replaces the logical node with an executable physical node for the physical model.

@@ -42,7 +42,7 @@ class AbstractLogicalNode(Node):


class LogicalGraph(Graph):
    def __init__(self, model: Model, graph_id: int, name: str = None, _internal: bool = False):
    def __init__(self, model: GraphModelSpace, graph_id: int, name: str = None, _internal: bool = False):
        super().__init__(model, graph_id, name='logical_' + name, _internal=_internal)

    def _dump(self) -> Any:

@@ -71,7 +71,7 @@ class LogicalGraph(Graph):
            'edges': edges_dump
        }

    def _fork_to(self, model: Model) -> Graph:
    def _fork_to(self, model: GraphModelSpace) -> Graph:
        new_graph = Graph(model, self.id, self.name,
                          _internal=True)._register()


@@ -106,7 +106,7 @@ class OriginNode(AbstractLogicalNode):
        self.original_graph = original_graph
        self.original_node = original_node

    def assemble(self, multi_model_placement: Dict[Model, Device]) -> Tuple[Node, Device]:
    def assemble(self, multi_model_placement: Dict[GraphModelSpace, Device]) -> Tuple[Node, Device]:
        model_id = self.original_node.graph.model.model_id
        new_node = Node(self.original_node.graph, self.original_node.id,
                        f"M_{model_id}_" +

@@ -124,15 +124,17 @@ class OriginNode(AbstractLogicalNode):


class LogicalPlan:
    def __init__(self, plan_id=0) -> None:
        self.lp_model = Model(_internal=True)
    def __init__(self, model_cls: Type[GraphModelSpace], plan_id: int = 0) -> None:
        # GraphModelSpace has multiple implementations based on the framework.
        self.model_cls = model_cls
        self.lp_model = model_cls(_internal=True)
        self.id = plan_id
        self.logical_graph = LogicalGraph(
            self.lp_model, self.id, name=f'{self.id}', _internal=True)._register()
        self.lp_model._root_graph_name = self.logical_graph.name
        self.models = []

    def add_model(self, model: Model):
    def add_model(self, model: GraphModelSpace):
        self.models.append(model)
        # Only optimize the root graph.
        self._merge_graph(model.root_graph)

@@ -152,8 +154,8 @@ class LogicalPlan:
            new_tail = id_to_new_node[edge.tail.id]
            Edge((new_head, edge.head_slot), (new_tail, edge.tail_slot), _internal=True)._register()

    def assemble(self, multi_model_placement: Dict[Model, Device]) \
        -> Tuple[Model, Dict[Node, Device]]:
    def assemble(self, multi_model_placement: Dict[GraphModelSpace, Device]) \
        -> Tuple[GraphModelSpace, Dict[Node, Device]]:
        """
        Given a set of models to be formed in a physical model and their device placement,
        this function replaces all the logical node in this LogicalPlan with executable physical nodes

@@ -167,13 +169,13 @@ class LogicalPlan:

        Returns
        -------
        phy_model : Model
        phy_model : GraphModelSpace
            the physical model formed by models in `multi_model_placement`
            all logical node are replaced by physical nodes
        node_placements : dict
            the device placement of the nodes in `phy_model`
        """
        phy_model = Model(_internal=True)
        phy_model = self.model_cls(_internal=True)
        phy_graph = self.lp_model.root_graph._fork_to(phy_model)
        phy_graph._rename_graph(phy_graph.name, "_model")
@@ -3,17 +3,12 @@

from typing import List, Dict, Tuple

from nni.nas.utils import uid
from nni.nas.evaluator.pytorch.cgo.evaluator import MultiModelSupervisedLearningModule
from nni.mutable.utils import uid
from nni.common.device import GPUDevice

from nni.nas.execution.common.graph import Graph, Model, Node
from nni.nas.space import GraphModelSpace, Graph, Node
from .interface import AbstractOptimizer
from .logical_plan import (AbstractLogicalNode, LogicalGraph, LogicalPlan,
                           OriginNode)


_supported_evaluators = [MultiModelSupervisedLearningModule]
from .logical_plan import AbstractLogicalNode, LogicalGraph, LogicalPlan, OriginNode


class DedupInputNode(AbstractLogicalNode):

@@ -23,7 +18,6 @@ class DedupInputNode(AbstractLogicalNode):
    These models will share the result of once calculation.
    """

    def __init__(self, logical_graph: LogicalGraph, node_id: int,
                 nodes_to_dedup: List[Node], _internal=False):
        super().__init__(logical_graph, node_id,

@@ -32,7 +26,7 @@ class DedupInputNode(AbstractLogicalNode):
        self.origin_nodes: List[OriginNode] = nodes_to_dedup.copy()
        self.related_models = [_.original_graph.model for _ in self.origin_nodes]

    def assemble(self, multi_model_placement: Dict[Model, GPUDevice]) -> Tuple[Node, GPUDevice]:
    def assemble(self, multi_model_placement: Dict[GraphModelSpace, GPUDevice]) -> Tuple[Node, GPUDevice]:
        for node in self.origin_nodes:
            if node.original_graph.model in multi_model_placement:
                new_node = Node(node.original_graph, node.id,

@@ -54,10 +48,10 @@ class DedupInputOptimizer(AbstractOptimizer):
        pass

    def _check_supported_evaluator(self, evaluator):
        for e in _supported_evaluators:
            if isinstance(evaluator, e):
                return True
        return False
        # NOTE(yuge): I think this is buggy. But I'm not sure whether I should fix it.
        from nni.nas.execution.cgo.evaluator import MultiModelLightningModule
        _supported_evaluators = (MultiModelLightningModule, )
        return isinstance(evaluator, _supported_evaluators)

    def _check_deduplicate_by_node(self, root_node, node_to_check):
        if root_node == node_to_check:
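For illustration, a short sketch (not part of the diff) of how the refactored LogicalPlan and DedupInputOptimizer fit together. It follows the unit test at the end of this commit; `models` (frozen GraphModelSpace instances sharing one evaluator) and `multi_model_placement` (a model-to-device dict) are placeholders:

    from nni.nas.space.pytorch import PytorchGraphModelSpace
    from nni.nas.execution.cgo.logical_optimizer.logical_plan import LogicalPlan
    from nni.nas.execution.cgo.logical_optimizer.opt_dedup_input import DedupInputOptimizer

    # LogicalPlan now receives the concrete GraphModelSpace subclass, because the
    # merged "physical" model is instantiated from it via self.model_cls(_internal=True).
    plan = LogicalPlan(model_cls=PytorchGraphModelSpace)
    for model in models:   # `models`: placeholder for a batch of frozen GraphModelSpace instances
        plan.add_model(model)

    # Deduplicate identical input nodes so the shared computation runs only once.
    DedupInputOptimizer().convert(plan)

    # Given a {model: device} placement, assemble() returns one executable model
    # plus the per-node device placement consumed by the CGO middleware.
    phy_model, node_placements = plan.assemble(multi_model_placement)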
@@ -0,0 +1,394 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

__all__ = ['CrossGraphOptimization']

import logging
import time
import threading
from collections.abc import Iterable
from typing import List, Dict, Tuple, cast

from nni.common.device import GPUDevice, Device
from nni.experiment.config.training_services import RemoteConfig
from nni.nas.space import GraphModelSpace, Node, ModelStatus, ExecutableModelSpace
from nni.nas.execution.engine import Middleware, ExecutionEngine
from nni.nas.execution.event import ModelEventType, IntermediateMetricEvent, FinalMetricEvent, TrainingEndEvent
from nni.typehint import TrialMetric

from .logical_optimizer.logical_plan import LogicalPlan, AbstractLogicalNode
from .logical_optimizer.opt_dedup_input import DedupInputOptimizer

_logger = logging.getLogger(__name__)


class CrossGraphOptimization(Middleware):
    """
    The execution engine middleware of Cross-Graph Optimization (CGO).
    It's a technique that merges multiple models into one model for training speedup.
    See the `Retiarii paper <https://www.usenix.org/system/files/osdi20-zhang_quanlu.pdf>`__ for details.

    Currently, :class:`CrossGraphOptimization` is only a prototype.
    It's not fully tested, and it comes with a number of constraints on the model space and evaluator:

    - The models must be in the format of :class:`~nni.nas.space.GraphModelSpace`.
    - The evaluator has to be a :class:`~nni.nas.evaluator.pytorch.Lightning` evaluator.
    - The ``lightning_module`` argument of the evaluator must be an instance of
      :class:`~nni.nas.execution.cgo.evaluator.MultiModelLightningModule`.
    - The ``trainer`` argument of the evaluator must be an instance of
      :class:`~nni.nas.execution.cgo.evaluator.MultiModelTrainer`.

    There are also a number of limitations:

    - CGO doesn't support stopping and resuming from a checkpoint.
    - Only the remote training service is supported.
    - All model history is stored in memory. The experiment might not scale well.

    Parameters
    ----------
    remote_config
        The remote training service config.
    max_concurrency
        The maximum number of trials to run concurrently.
    batch_waiting_time
        Seconds to wait for each batch of trial submission.
        The trials within one batch could apply cross-graph optimization.
    """

    def __init__(self, remote_config: RemoteConfig,
                 max_concurrency: int | None = None,
                 batch_waiting_time: int = 60) -> None:
        super().__init__()

        _logger.warning('Cross graph optimization is an experimental feature. Usages are subject to change.')

        self._history: List[GraphModelSpace] = []
        self._running_models: Dict[int, GraphModelSpace] = {}

        self.logical_plan_counter = 0
        self.available_devices: List[Device] = []
        self.max_concurrency: int | None = max_concurrency

        devices = self._construct_devices(remote_config)
        for device in devices:
            self.available_devices.append(device)
        self.all_devices = self.available_devices.copy()

        self._batch_waiting_time = batch_waiting_time  # seconds to wait for all models in a batch to do cross-graph optimization
        self._optimizers = [DedupInputOptimizer()]
        self._original_models: Dict[int, GraphModelSpace] = {}
        self._original_model_to_multi_model: Dict[int, GraphModelSpace] = {}
        self._trial_to_original_models: Dict[int, List[GraphModelSpace]] = {}
        self._trial_used_devices: Dict[int, List[Device]] = {}

        self._queuing_models: List[GraphModelSpace] = []
        self._models_to_retry: List[GraphModelSpace] = []
        self._queue_lock = threading.Lock()

        self._stopped = False
        self._consumer_thread = threading.Thread(target=self._consume_models)
        self._consumer_thread.start()

    def _construct_devices(self, training_service):
        devices = []
        if hasattr(training_service, 'machine_list'):
            for machine in cast(RemoteConfig, training_service).machine_list:
                assert machine.gpu_indices is not None, \
                    'gpu_indices must be set in RemoteMachineConfig for CGO execution engine'
                assert isinstance(machine.gpu_indices, list), 'gpu_indices must be a list'
                for gpu_idx in machine.gpu_indices:
                    devices.append(GPUDevice(machine.host, gpu_idx))
        return devices

    def shutdown(self):
        self._stopped = True
        self._consumer_thread.join()

        self.engine.unregister_model_event_callback(ModelEventType.TrainingEnd, self._training_end_callback)
        self.engine.unregister_model_event_callback(ModelEventType.FinalMetric, self._final_metric_callback)
        self.engine.unregister_model_event_callback(ModelEventType.IntermediateMetric, self._intermediate_metric_callback)

        self.engine.shutdown()

    def load_state_dict(self, state_dict: dict) -> None:
        _logger.info('Cross graph optimization does not preserve any states by itself. Loading the state of inner engine: %s', self.engine)
        return self.engine.load_state_dict(state_dict)

    def state_dict(self) -> dict:
        return self.engine.state_dict()

    def set_engine(self, engine: ExecutionEngine) -> None:
        super().set_engine(engine)
        self.engine.register_model_event_callback(ModelEventType.TrainingEnd, self._training_end_callback)
        self.engine.register_model_event_callback(ModelEventType.FinalMetric, self._final_metric_callback)
        self.engine.register_model_event_callback(ModelEventType.IntermediateMetric, self._intermediate_metric_callback)

    def add_optimizer(self, opt):
        self._optimizers.append(opt)

    def submit_models(self, *models: GraphModelSpace) -> None:
        if any(not isinstance(model, GraphModelSpace) for model in models):
            raise TypeError('Cross graph optimization only supports GraphModelSpace.')
        curr_time = time.time()
        _logger.info('%d models are submitted.', len(models))
        with self._queue_lock:
            self._queuing_models.extend([(curr_time, _) for _ in models])
            self._history.extend(models)

    def _submit_retry_models(self, models: List[GraphModelSpace]) -> None:
        _logger.info('%d models are retried.', len(models))
        with self._queue_lock:
            self._models_to_retry.extend(models)

    def _consume_models(self):
        # a thread to monitor self._models_to_retry and self._queuing_models to consume them in batch
        while not self._stopped:
            # retrying jobs should be first scheduled.
            while self._models_to_retry:
                with self._queue_lock:
                    # Get next model and lock the resource.
                    if len(self.available_devices) > 0:
                        m = self._models_to_retry[0]
                        self._models_to_retry = self._models_to_retry[1:]
                        m = self._schedule_models_in_batch(m)
                    else:
                        break

                # submit the single model to avoid cross-graph optimization.
                self.engine.submit_models(*m)

                time.sleep(1)

            # Submit merged models
            merged_models = []

            with self._queue_lock:
                curr_time = time.time()

                num_models_to_submit = len(self.available_devices)
                if self.max_concurrency is not None:
                    num_models_to_submit = min(num_models_to_submit, self.max_concurrency)

                if self._queuing_models and curr_time - self._queuing_models[0][0] >= self._batch_waiting_time:
                    num_models_to_submit = min(num_models_to_submit, len(self._queuing_models))
                    if num_models_to_submit > 0:
                        merged_models = list(self._schedule_models_in_batch(*[_[1] for _ in self._queuing_models[:num_models_to_submit]]))
                        self._queuing_models = self._queuing_models[num_models_to_submit:]
                        _logger.debug('Scheduled %d models in batch.', num_models_to_submit)

            # Outside lock to avoid deadlock.
            if merged_models:
                self.engine.submit_models(*merged_models)

            time.sleep(1)

    def _schedule_models_in_batch(self, *models: GraphModelSpace) -> Iterable[GraphModelSpace]:
        _logger.info('%d models are scheduled in batch.', len(models))
        _logger.debug('Scheduled model ids: %s', [m.model_id for m in models])
        for model in models:
            model.status = ModelStatus.Training
        logical = self._build_logical(models)

        for opt in self._optimizers:
            opt.convert(logical)

        for model, grouped_models in self._assemble(logical):
            assert model.placement is not None
            _logger.debug('Created grouped model %d. Original model ids: %s', model.model_id, [m.model_id for m in grouped_models])

            # unique non-cpu devices used by the trial
            self._trial_used_devices[model.model_id] = list(set([_ for _ in model.placement.values() if isinstance(_, GPUDevice)]))
            _logger.debug('Model %d uses devices: %s', model.model_id, self._trial_used_devices[model.model_id])

            # currently, it is impossible for search strategy to submit models more than the number of available devices
            for used_device in self._trial_used_devices[model.model_id]:
                self.available_devices.remove(used_device)  # used_device must be in self.available_devices
            self._running_models[model.model_id] = model

            self._trial_to_original_models[model.model_id] = []
            for m in grouped_models:
                self._original_models[m.model_id] = m
                self._original_model_to_multi_model[m.model_id] = model
                self._trial_to_original_models[model.model_id].append(m.model_id)

            yield model

    def list_models(self) -> Iterable[GraphModelSpace]:
        return self._history

    def idle_worker_available(self) -> bool:
        # the _queuing_models need to use available_devices first
        with self._queue_lock:
            available_for_more_models = len(self.available_devices) - len(self._queuing_models) - len(self._models_to_retry)
        return available_for_more_models

    def budget_available(self) -> bool:
        return self.engine.budget_available()

    def _assemble(self, logical_plan: LogicalPlan) -> Iterable[Tuple[GraphModelSpace, List[GraphModelSpace]]]:
        """
        Return the assembled models as a list of tuples.
        Each tuple contains the assembled model, the device placement of graph nodes, and the original models.
        """
        # try to use the available_devices first so that it can be launched as early as possible
        # if free devices are not enough to assemble all models in one trial, try all devices
        if len(self.available_devices) > 0:
            grouped_models: List[Dict[GraphModelSpace, Device]] = AssemblePolicy().group(logical_plan, self.available_devices)

        if len(self.available_devices) == 0 or len(grouped_models) > 1:
            grouped_models: List[Dict[GraphModelSpace, Device]] = AssemblePolicy().group(logical_plan, self.all_devices)

        for multi_model in grouped_models:
            model, model_placement = logical_plan.assemble(multi_model)
            assert isinstance(model, GraphModelSpace), 'Assembled model must be a GraphModelSpace.'

            from nni.nas.evaluator.pytorch import Lightning
            from .evaluator import MultiModelLightningModule, MultiModelTrainer

            if not isinstance(model.evaluator, Lightning):
                raise TypeError('Cross-graph optimization only supports PyTorch Lightning as the evaluator.')
            if not isinstance(model.evaluator.module, MultiModelLightningModule):
                raise TypeError('Cross-graph optimization only supports MultiModelLightningModule.')
            if not isinstance(model.evaluator.trainer, MultiModelTrainer):
                raise TypeError('Cross-graph optimization only supports MultiModelTrainer.')

            # Set n_models of the lightning module.
            model.evaluator.module.n_models = len(multi_model)
            model.status = ModelStatus.Frozen
            model.placement = model_placement
            model.metrics.strict = False

            yield model, multi_model.keys()

    def _build_logical(self, models: List[GraphModelSpace]) -> LogicalPlan:
        assert len(models) > 0
        logical_plan = LogicalPlan(model_cls=models[0].__class__, plan_id=self.logical_plan_counter)
        for model in models:
            logical_plan.add_model(model)
        self.logical_plan_counter += 1
        return logical_plan

    def _training_end_callback(self, event: TrainingEndEvent) -> None:
        model = cast(GraphModelSpace, event.model)
        _logger.debug(f'Training end for merged model {model.model_id}.')
        model = self._running_models[model.model_id]
        models_to_retry = []
        for model_id in self._original_model_to_multi_model:
            if self._original_model_to_multi_model[model_id] == model:
                original_model = self._original_models[model_id]
                if model.status == ModelStatus.Trained:
                    self.dispatch_model_event(TrainingEndEvent(original_model, ModelStatus.Trained))
                else:
                    # the failed models in a multi-model will be retried one by one w/o CGO
                    if len(self._trial_to_original_models[model.model_id]) > 1:
                        # TODO: should the listeners be notified?
                        original_model.status = ModelStatus.Frozen
                        original_model.metrics.clear()
                        models_to_retry.append(original_model)
                    else:
                        self.dispatch_model_event(TrainingEndEvent(original_model, ModelStatus.Failed))

        if len(models_to_retry) > 0:
            self._submit_retry_models(models_to_retry)

        self.available_devices.extend(self._trial_used_devices[model.model_id])
        self.available_devices = sorted(list(set(self.available_devices)))
        del self._running_models[model.model_id]

    def _intermediate_metric_callback(self, event: IntermediateMetricEvent) -> None:
        model = cast(GraphModelSpace, event.model)
        metrics = cast(List[TrialMetric], event.metric)
        _logger.debug(f'Received intermediate metrics for merged model {model.model_id}: {metrics}')
        if not isinstance(metrics, Iterable):
            raise TypeError('Intermediate metrics must be a list of TrialMetric.')
        if len(metrics) != len(self._trial_to_original_models[model.model_id]):
            raise ValueError('Number of intermediate metrics must be equal to number of original models.')

        merged_metrics: Dict[int, TrialMetric] = {}
        for idx, _ in enumerate(metrics):
            merged_metrics[self._trial_to_original_models[model.model_id][idx]] = metrics[idx]
        for model_id in merged_metrics:
            self.dispatch_model_event(IntermediateMetricEvent(self._original_models[model_id], merged_metrics[model_id]))

    def _final_metric_callback(self, event: FinalMetricEvent) -> None:
        model = cast(GraphModelSpace, event.model)
        metrics = cast(List[TrialMetric], event.metric.final)
        _logger.debug(f'Received final metrics for merged model {model.model_id}: {metrics}')
        if not isinstance(metrics, Iterable):
            raise TypeError('Final metrics must be a list of TrialMetric.')
        if len(metrics) != len(self._trial_to_original_models[model.model_id]):
            raise ValueError('Number of final metrics must be equal to number of original models.')

        merged_metrics: Dict[int, TrialMetric] = {}
        for idx, _ in enumerate(metrics):
            merged_metrics[self._trial_to_original_models[model.model_id][idx]] = metrics[idx]

        _logger.debug(f'Mapped to metrics of original models: {merged_metrics}')

        for model_id in merged_metrics:
            self.dispatch_model_event(FinalMetricEvent(self._original_models[model_id], merged_metrics[model_id]))


class AssemblePolicy:
    @staticmethod
    def _is_related_node(model: GraphModelSpace, node: Node):
        if isinstance(node, AbstractLogicalNode):
            if model in node.related_models:
                return True
        else:
            if model == node.graph.model:
                return True
        return False

    @staticmethod
    def _check_graph_connectivity(model: GraphModelSpace,
                                  group_model: Dict[GraphModelSpace, Device],
                                  logical_plan: LogicalPlan) -> bool:
        for edge in logical_plan.logical_graph.edges:
            if AssemblePolicy._is_related_node(model, edge.head) or \
                    AssemblePolicy._is_related_node(model, edge.tail):
                for grouped_model in group_model:
                    if AssemblePolicy._is_related_node(grouped_model, edge.head) or \
                            AssemblePolicy._is_related_node(grouped_model, edge.tail):
                        return True
        return False

    @staticmethod
    def _check_evaluator(new_model: GraphModelSpace, group_model: Dict[GraphModelSpace, Device]) -> bool:
        from nni.nas.evaluator.pytorch import Lightning
        from .evaluator import MultiModelLightningModule, MultiModelTrainer

        if not (isinstance(new_model.evaluator, Lightning)
                and isinstance(new_model.evaluator.module, MultiModelLightningModule)
                and isinstance(new_model.evaluator.trainer, MultiModelTrainer)):
            return False
        for m in group_model:
            if not m.evaluator == new_model.evaluator:
                return False
        return True

    @staticmethod
    def group(logical_plan, available_devices):
        # TODO: Packing multiple models in one GPU.
        # Currently, we only support one model per GPU.
        all_grouped_models = []
        group_model = {}
        assert(len(available_devices) > 0)  # There should be at least 1 device, set in CGO_DEVICES
        for idx, m in enumerate(logical_plan.models):
            # models in one group should
            # (1) not use more GPUs than available_devices
            # (2) be connected in the logical plan (independent models should be assembled in multiple groups)
            # (3) use the same evaluator
            if len(group_model) > 0 and \
                (AssemblePolicy._check_graph_connectivity(m, group_model, logical_plan) == False or
                    AssemblePolicy._check_evaluator(m, group_model) == False):
                all_grouped_models.append(group_model)
                group_model = {}
            group_model[m] = available_devices[idx % len(available_devices)]
            if len(group_model) == len(available_devices) or \
                    idx == len(logical_plan.models) - 1:
                all_grouped_models.append(group_model)
                group_model = {}
        return all_grouped_models
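For illustration, a rough wiring sketch (not part of the diff): CrossGraphOptimization is a Middleware, so it sits in front of an underlying ExecutionEngine and forwards merged models to it. The host name, GPU indices, and the choice of SequentialExecutionEngine below are placeholder assumptions made for the example; the unit test at the end of this commit uses the same construction pattern with batch_waiting_time=0:

    from nni.experiment.config import RemoteConfig, RemoteMachineConfig
    from nni.nas.execution import SequentialExecutionEngine
    from nni.nas.execution.cgo import CrossGraphOptimization

    # Hypothetical remote machine; gpu_indices must be set explicitly for CGO.
    remote = RemoteConfig(machine_list=[RemoteMachineConfig(host='worker-0', gpu_indices=[0, 1])])

    cgo = CrossGraphOptimization(remote_config=remote, max_concurrency=2, batch_waiting_time=60)
    # set_engine() registers the metric/training-end callbacks on the inner engine.
    cgo.set_engine(SequentialExecutionEngine())

    # Models submitted within one batch_waiting_time window may be merged into a single trial.
    # `models`: GraphModelSpace instances with a Lightning(MultiModelLightningModule, MultiModelTrainer) evaluator.
    cgo.submit_models(*models)

    cgo.shutdown()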
@@ -1,402 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

__all__ = ['CGOExecutionEngine', 'TrialSubmission']

import logging
import os
import random
import string
import time
import threading
from typing import Iterable, List, Dict, Tuple, cast
from dataclasses import dataclass

from nni.common.device import GPUDevice, Device
from nni.experiment.config.training_services import RemoteConfig
from nni.nas import utils
from nni.nas.execution.common import (
    AbstractExecutionEngine, AbstractGraphListener, WorkerInfo,
    Model, ModelStatus, MetricData, Node,
    RetiariiAdvisor, send_trial, receive_trial_parameters, get_advisor,
)
from nni.nas.execution.pytorch import codegen
from nni.nas.evaluator.pytorch.lightning import Lightning
from nni.nas.evaluator.pytorch.cgo.evaluator import _MultiModelSupervisedLearningModule
from nni.nas.execution.pytorch.graph import BaseGraphData

from .logical_optimizer.logical_plan import LogicalPlan, AbstractLogicalNode
from .logical_optimizer.opt_dedup_input import DedupInputOptimizer

_logger = logging.getLogger(__name__)


def _noop(*args, **kwargs):
    pass


@dataclass
class TrialSubmission:
    model: Model
    placement: Dict[Node, Device]
    grouped_models: List[Model]


class CGOExecutionEngine(AbstractExecutionEngine):
    """
    The execution engine with Cross-Graph Optimization (CGO).

    Only models using PyTorch Lightning and MultiModelSupervisedLearningModule as the evaluator can be optimized.
    Otherwise, a model will be submitted independently without any cross-graph optimization.

    Parameters
    ----------
    training_service
        The remote training service config.
    max_concurrency
        The maximum number of trials to run concurrently.
    batch_waiting_time
        Seconds to wait for each batch of trial submission.
        The trials within one batch could apply cross-graph optimization.
    rest_port
        The port of the experiment's rest server.
    rest_url_prefix
        The URL prefix of the experiment's rest entry.
    """

    def __init__(self, training_service: RemoteConfig,
                 max_concurrency: int = None,
                 batch_waiting_time: int = 60,
                 rest_port: int | None = None,
                 rest_url_prefix: str | None = None
                 ) -> None:
        self.port = rest_port
        self.url_prefix = rest_url_prefix

        self._listeners: List[AbstractGraphListener] = []
        self._running_models: Dict[int, Model] = dict()
        self.logical_plan_counter = 0
        self.available_devices: List[Device] = []
        self.max_concurrency: int = max_concurrency

        devices = self._construct_devices(training_service)
        for device in devices:
            self.available_devices.append(device)
        self.all_devices = self.available_devices.copy()

        self._batch_waiting_time = batch_waiting_time  # seconds to wait for all models in a batch to do cross-graph optimization
        self._optimizers = [DedupInputOptimizer()]
        self._original_models = {}
        self._original_model_to_multi_model = {}
        self._trial_to_original_models = {}
        self._trial_used_devices: Dict[int, List[Device]] = {}

        self._history: List[Model] = []

        self._queuing_models: List[Model] = []
        self._models_to_retry: List[Model] = []
        self._queue_lock = threading.Lock()

        # register advisor callbacks
        advisor: RetiariiAdvisor = get_advisor()
        advisor.register_callbacks({
            'send_trial': _noop,
            'request_trial_jobs': _noop,
            'trial_end': self._trial_end_callback,
            'intermediate_metric': self._intermediate_metric_callback,
            'final_metric': self._final_metric_callback
        })

        self._stopped = False
        self._consumer_thread = threading.Thread(target=self._consume_models)
        self._consumer_thread.start()

    def _construct_devices(self, training_service):
        devices = []
        if hasattr(training_service, 'machine_list'):
            for machine in cast(RemoteConfig, training_service).machine_list:
                assert machine.gpu_indices is not None, \
                    'gpu_indices must be set in RemoteMachineConfig for CGO execution engine'
                assert isinstance(machine.gpu_indices, list), 'gpu_indices must be a list'
                for gpu_idx in machine.gpu_indices:
                    devices.append(GPUDevice(machine.host, gpu_idx))
        return devices

    def join(self):
        self._stopped = True
        self._consumer_thread.join()

    def add_optimizer(self, opt):
        self._optimizers.append(opt)

    def submit_models(self, *models: List[Model]) -> None:
        curr_time = time.time()
        _logger.info('%d models are submitted', len(models))
        self._queue_lock.acquire()
        self._queuing_models.extend([(curr_time, _) for _ in models])
        self._queue_lock.release()

    def _submit_retry_models(self, models: List[Model]) -> None:
        _logger.info('%d models are retried', len(models))
        self._queue_lock.acquire()
        self._models_to_retry.extend(models)
        self._queue_lock.release()

    def _consume_models(self):
        # a thread to monitor self._models_to_retry and self._queuing_models to consume them in batch
        while not self._stopped:
            if len(self._models_to_retry) > 0:
                self._queue_lock.acquire()
                # retrying jobs should be first scheduled.
                for m in self._models_to_retry:
                    if len(self.available_devices) > 0:
                        self._submit_models_in_batch(m)  # submit the single model to avoid cross-graph optimization.
                        self._models_to_retry = self._models_to_retry[1:]
                self._queue_lock.release()

            if len(self._queuing_models) > 0:
                self._queue_lock.acquire()
                curr_time = time.time()

                num_models_to_submit = len(self.available_devices)
                if self.max_concurrency:
                    num_models_to_submit = min(num_models_to_submit, self.max_concurrency)

                if curr_time - self._queuing_models[0][0] > self._batch_waiting_time:
                    num_models_to_submit = min(num_models_to_submit, len(self._queuing_models))
                if num_models_to_submit > 0:
                    self._submit_models_in_batch(*[_[1] for _ in self._queuing_models[:num_models_to_submit]])
                    self._queuing_models = self._queuing_models[num_models_to_submit:]
                self._queue_lock.release()
            time.sleep(1)

    def _extract_placement_constaint(self, placement_mapping: Dict[Node, Device]):
        unique_gpus = sorted(list(set([e for e in placement_mapping.values() if isinstance(e, GPUDevice)])))
        placement_constraint = None
        if len(unique_gpus) > 0:
            placement_constraint = {}
            placement_constraint['type'] = 'Device'
            placement_constraint['gpus'] = [(e.node_id, e.gpu_id) for e in unique_gpus]
        return placement_constraint

    def _submit_models_in_batch(self, *models: List[Model]) -> None:
        _logger.info('%d models are submitted in batch', len(models))
        _logger.debug('model id: %s', str([m.model_id for m in models]))
        logical = self._build_logical(models)

        for opt in self._optimizers:
            opt.convert(logical)

        phy_models_and_placements = self._assemble(logical)
        for model, placement, grouped_models in phy_models_and_placements:
            data = BaseGraphData(codegen.model_to_pytorch_script(model, placement=placement), model.evaluator, {})
            placement_constraint = self._extract_placement_constaint(placement)
            trial_id = send_trial(data.dump(), placement_constraint=placement_constraint)
            # unique non-cpu devices used by the trial
            self._trial_used_devices[trial_id] = list(set([_ for _ in placement.values() if isinstance(_, GPUDevice)]))

            # currently, it is impossible for search strategy to submit models more than the number of available devices
            for used_device in self._trial_used_devices[trial_id]:
                self.available_devices.remove(used_device)  # used_device must be in self.available_devices
            self._running_models[trial_id] = model

            self._trial_to_original_models[trial_id] = []
            for m in grouped_models:
                self._original_models[m.model_id] = m
                self._original_model_to_multi_model[m.model_id] = model
                self._trial_to_original_models[trial_id].append(m.model_id)
                self._history.append(m)

    def list_models(self) -> Iterable[Model]:
        return self._history

    def _assemble(self, logical_plan: LogicalPlan) -> List[Tuple[Model, Dict[Node, Device], List[Model]]]:
        """
        Return the assembled models as a list of tuple.
        Each tuple contains the assembled model, the device placement of graph nodes, and the original models.
        """
        # try to use the available_devices first so that it can be launched as early as possible
        # if free devices are not enough to assemble all models in one trial, try all devices
        if len(self.available_devices) > 0:
            grouped_models: List[Dict[Model, Device]] = AssemblePolicy().group(logical_plan, self.available_devices)

        if len(self.available_devices) == 0 or len(grouped_models) > 1:
            grouped_models: List[Dict[Model, Device]] = AssemblePolicy().group(logical_plan, self.all_devices)

        phy_models_and_placements = []
        for multi_model in grouped_models:
            model, model_placement = logical_plan.assemble(multi_model)
            assert isinstance(model.evaluator, Lightning), \
                "cross-graph optimization only supports pytorch lighting as evaluator"
            assert isinstance(model.evaluator.module, _MultiModelSupervisedLearningModule), \
                "cross-graph optimization only support MultiModelSupervisedLearningModule"

            # replace the module with a new instance whose n_models is set
            # n_models must be set in __init__, otherwise it cannot be captured by serialize_cls
            new_module_init_params = model.evaluator.module.dump_kwargs().copy()

            # MultiModelSupervisedLearningModule hides n_models of _MultiModelSupervisedLearningModule from users
            new_module_init_params['n_models'] = len(multi_model)
            new_module = _MultiModelSupervisedLearningModule(**new_module_init_params)
            model.evaluator.module = new_module
            phy_models_and_placements.append((model, model_placement, multi_model.keys()))
        return phy_models_and_placements

    def _build_logical(self, models: List[Model]) -> LogicalPlan:
        logical_plan = LogicalPlan(plan_id=self.logical_plan_counter)
        for model in models:
            logical_plan.add_model(model)
        self.logical_plan_counter += 1
        return logical_plan

    def register_graph_listener(self, listener: AbstractGraphListener) -> None:
        self._listeners.append(listener)

    # def _send_trial_callback(self, paramater: dict) -> None:
    #     if len(self.available_devices) == 0:
    #         _logger.warning('There is no available devices, but trial is submitted.')
    #     _logger.debug('Resource used. Remaining: %d', len(self.available_devices))

    # def _request_trial_jobs_callback(self, num_trials: int) -> None:
    #     self.resources += num_trials
    #     _logger.info('on_resource_available: %d', self.resources)

    def _trial_end_callback(self, trial_id: int, success: bool) -> None:
        model = self._running_models[trial_id]
        if success:
            model.status = ModelStatus.Trained
        else:
            model.status = ModelStatus.Failed
        models_to_retry = []
        for model_id in self._original_model_to_multi_model:
            if self._original_model_to_multi_model[model_id] == model:
                original_model = self._original_models[model_id]
                if success:
                    original_model.status = ModelStatus.Trained
                else:
                    original_model.status = ModelStatus.Failed
                    # the failed models in a multi-model will be retried one by one w/o CGO
                    if len(self._trial_to_original_models[trial_id]) > 1:
                        models_to_retry.append(original_model)
                for listener in self._listeners:
                    listener.on_training_end(original_model, success)

        if len(models_to_retry) > 0:
            self._submit_retry_models(models_to_retry)

        self.available_devices.extend(self._trial_used_devices[trial_id])
        self.available_devices = sorted(list(set(self.available_devices)))
        del self._running_models[trial_id]

    def _intermediate_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
        merged_metrics = {}
        for idx, _ in enumerate(metrics):
            merged_metrics[self._trial_to_original_models[trial_id][idx]] = metrics[idx]
        for model_id in merged_metrics:
            self._original_models[model_id].intermediate_metrics.append(merged_metrics[model_id])
            for listener in self._listeners:
                listener.on_intermediate_metric(self._original_models[model_id], merged_metrics[model_id])

    def _final_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
        _logger.debug(metrics)

        if isinstance(metrics, float):
            self._listeners[0].on_metric(self._running_models[trial_id], metrics)
        else:
            merged_metrics = {}
            for idx, _ in enumerate(metrics):
                merged_metrics[self._trial_to_original_models[trial_id][idx]] = metrics[idx]
            for model_id in merged_metrics:
                self._original_models[model_id].metric = merged_metrics[model_id]
                for listener in self._listeners:
                    listener.on_metric(self._original_models[model_id], merged_metrics[model_id])

    def query_available_resource(self) -> List[WorkerInfo]:
        # the _queuing_models need to use available_devices first
        self._queue_lock.acquire()
        available_for_more_models = len(self.available_devices) - len(self._queuing_models) - len(self._models_to_retry)
        self._queue_lock.release()
        return available_for_more_models

    def budget_exhausted(self) -> bool:
        advisor = get_advisor()
        return advisor.stopping

    @classmethod
    def trial_execute_graph(cls) -> None:
        """
        Initialize the model, hand it over to trainer.
        """
        graph_data = BaseGraphData.load(receive_trial_parameters())
        _logger.info('CGO_ENGINE trial parameters received')
        random_str = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(6))
        file_name = f'_generated_model/{random_str}.py'
        os.makedirs(os.path.dirname(file_name), exist_ok=True)
        with open(file_name, 'w') as f:
            f.write(graph_data.model_script)

        trainer_instance = graph_data.evaluator
        model_cls = utils.import_(f'_generated_model.{random_str}._model')

        trainer_instance.fit(model_cls())
        os.remove(file_name)


class AssemblePolicy:
    @staticmethod
    def _is_related_node(model: Model, node: Node):
        if isinstance(node, AbstractLogicalNode):
            if model in node.related_models:
                return True
        else:
            if model == node.graph.model:
                return True
        return False

    @staticmethod
    def _check_graph_connectivity(model: Model,
                                  group_model: Dict[Model, Device],
                                  logical_plan: LogicalPlan) -> bool:
        for edge in logical_plan.logical_graph.edges:
            if AssemblePolicy._is_related_node(model, edge.head) or \
                    AssemblePolicy._is_related_node(model, edge.tail):
                for grouped_model in group_model:
                    if AssemblePolicy._is_related_node(grouped_model, edge.head) or \
                            AssemblePolicy._is_related_node(grouped_model, edge.tail):
                        return True
        return False

    @staticmethod
    def _check_evaluator(new_model: Model, group_model: Dict[Model, Device]) -> bool:
        if not (isinstance(new_model.evaluator, Lightning)
                and isinstance(new_model.evaluator.module, _MultiModelSupervisedLearningModule)):
            return False
        for m in group_model:
            if not m.evaluator == new_model.evaluator:
                return False
        return True

    @staticmethod
    def group(logical_plan, available_devices):
        # TODO: Packing multiple model in one GPU
        # Currently, we only support one model per GPU
        all_grouped_models = []
        group_model = {}
        assert(len(available_devices) > 0)  # There should be at least 1 device, set in CGO_DEVICES
        for idx, m in enumerate(logical_plan.models):
            # models in one group should
            # (1) not use more GPUs than available_devices
            # (2) be connected in the logical plan (independent models should be assembled in multiple groups)
            # (3) use same MultiModelSupervisedLearningModule
            if len(group_model) > 0 and \
                (AssemblePolicy._check_graph_connectivity(m, group_model, logical_plan) == False or
                    AssemblePolicy._check_evaluator(m, group_model) == False):
                all_grouped_models.append(group_model)
                group_model = {}
            group_model[m] = available_devices[idx % len(available_devices)]
            if len(group_model) == len(available_devices) or \
                    idx == len(logical_plan.models) - 1:
                all_grouped_models.append(group_model)
                group_model = {}
        return all_grouped_models
@@ -0,0 +1,41 @@
{
    "_model": {
        "inputs": ["image"],
        "outputs": ["metric"],

        "nodes": {
            "stem": {"operation": {"type": "_cell", "cell_name": "stem"}},
            "flatten": {"operation": {"type": "__torch__.torch.nn.Flatten"}},
            "fc1": {"operation": {"type": "__torch__.torch.nn.Linear", "parameters": {"out_features": 256, "in_features": 1024}}},
            "fc2": {"operation": {"type": "__torch__.torch.nn.Linear", "parameters": {"out_features": 10, "in_features": 256}}},
            "softmax": {"operation": {"type": "__torch__.torch.nn.Softmax"}}
        },

        "edges": [
            {"head": ["_inputs", 0], "tail": ["stem", null]},
            {"head": ["stem", null], "tail": ["flatten", null]},
            {"head": ["flatten", null], "tail": ["fc1", null]},
            {"head": ["fc1", null], "tail": ["fc2", null]},
            {"head": ["fc2", null], "tail": ["softmax", null]},
            {"head": ["softmax", null], "tail": ["_outputs", 0]}
        ]
    },

    "stem": {
        "nodes": {
            "conv1": {"operation": {"type": "__torch__.torch.nn.Conv2d", "parameters": {"out_channels": 32, "in_channels": 1, "kernel_size": 5}}},
            "pool1": {"operation": {"type": "__torch__.torch.nn.MaxPool2d", "parameters": {"kernel_size": 2}}},
            "conv2": {"operation": {"type": "__torch__.torch.nn.Conv2d", "parameters": {"out_channels": 64, "in_channels": 32, "kernel_size": 5}}},
            "pool2": {"operation": {"type": "__torch__.torch.nn.MaxPool2d", "parameters": {"kernel_size": 2}}}
        },

        "edges": [
            {"head": ["_inputs", 0], "tail": ["conv1", null]},
            {"head": ["conv1", null], "tail": ["pool1", null]},
            {"head": ["pool1", null], "tail": ["conv2", null]},
            {"head": ["conv2", null], "tail": ["pool2", null]},
            {"head": ["pool2", null], "tail": ["_outputs", 0]}
        ]
    }

}
@ -0,0 +1,270 @@
|
|||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchmetrics
|
||||
from torchvision.datasets import MNIST
|
||||
from torchvision import transforms
|
||||
from pytorch_lightning.utilities.seed import seed_everything
|
||||
|
||||
import nni
|
||||
from nni.experiment.config import RemoteConfig, RemoteMachineConfig
|
||||
from nni.nas.evaluator.pytorch import Lightning, DataLoader
|
||||
from nni.nas.execution import SequentialExecutionEngine
|
||||
from nni.nas.execution.cgo import CrossGraphOptimization
|
||||
from nni.nas.execution.cgo.evaluator import MultiModelLightningModule, MultiModelTrainer
|
||||
from nni.nas.execution.cgo.logical_optimizer.logical_plan import LogicalPlan
|
||||
from nni.nas.execution.cgo.logical_optimizer.opt_dedup_input import DedupInputOptimizer
|
||||
from nni.nas.space import Node, ModelStatus
|
||||
from nni.nas.space.pytorch import PytorchGraphModelSpace
|
||||
from nni.runtime.trial_command_channel import get_default_trial_command_channel, set_default_trial_command_channel
|
||||
|
||||
from ut.sdk.helper.trial_command_channel import TestHelperTrialCommandChannel
|
||||
|
||||
|
||||
class _model_cpu(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.M_1_stem = M_1_stem()
|
||||
self.M_2_stem = M_2_stem()
|
||||
self.M_1_flatten = torch.nn.Flatten()
|
||||
self.M_2_flatten = torch.nn.Flatten()
|
||||
self.M_1_fc1 = torch.nn.Linear(out_features=256, in_features=1024)
|
||||
self.M_2_fc1 = torch.nn.Linear(out_features=256, in_features=1024)
|
||||
self.M_1_fc2 = torch.nn.Linear(out_features=10, in_features=256)
|
||||
self.M_2_fc2 = torch.nn.Linear(out_features=10, in_features=256)
|
||||
self.M_1_softmax = torch.nn.Softmax()
|
||||
self.M_2_softmax = torch.nn.Softmax()
|
||||
|
||||
def forward(self, *_inputs):
|
||||
M_1__inputs_to_M_2_stem = _inputs[0]
|
||||
M_1_stem = self.M_1_stem(_inputs[0])
|
||||
M_2_stem = self.M_2_stem(M_1__inputs_to_M_2_stem)
|
||||
M_1_flatten = self.M_1_flatten(M_1_stem)
|
||||
M_2_flatten = self.M_2_flatten(M_2_stem)
|
||||
M_1_fc1 = self.M_1_fc1(M_1_flatten)
|
||||
M_2_fc1 = self.M_2_fc1(M_2_flatten)
|
||||
M_1_fc2 = self.M_1_fc2(M_1_fc1)
|
||||
M_2_fc2 = self.M_2_fc2(M_2_fc1)
|
||||
M_1_softmax = self.M_1_softmax(M_1_fc2)
|
||||
M_2_softmax = self.M_2_softmax(M_2_fc2)
|
||||
return M_1_softmax, M_2_softmax
|
||||
|
||||
|
||||
class _model_gpu(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.M_1_stem = M_1_stem().to('cuda:0')
|
||||
self.M_2_stem = M_2_stem().to('cuda:1')
|
||||
self.M_1_flatten = torch.nn.Flatten().to('cuda:0')
|
||||
self.M_2_flatten = torch.nn.Flatten().to('cuda:1')
|
||||
self.M_1_fc1 = torch.nn.Linear(out_features=256, in_features=1024).to('cuda:0')
|
||||
self.M_2_fc1 = torch.nn.Linear(out_features=256, in_features=1024).to('cuda:1')
|
||||
self.M_1_fc2 = torch.nn.Linear(out_features=10, in_features=256).to('cuda:0')
|
||||
self.M_2_fc2 = torch.nn.Linear(out_features=10, in_features=256).to('cuda:1')
|
||||
self.M_1_softmax = torch.nn.Softmax().to('cuda:0')
|
||||
self.M_2_softmax = torch.nn.Softmax().to('cuda:1')
|
||||
|
||||
def forward(self, *_inputs):
|
||||
M_1__inputs_to_M_1_stem = _inputs[0].to("cuda:0")
|
||||
M_1__inputs_to_M_2_stem = _inputs[0].to("cuda:1")
|
||||
M_1_stem = self.M_1_stem(M_1__inputs_to_M_1_stem)
|
||||
M_2_stem = self.M_2_stem(M_1__inputs_to_M_2_stem)
|
||||
M_1_flatten = self.M_1_flatten(M_1_stem)
|
||||
M_2_flatten = self.M_2_flatten(M_2_stem)
|
||||
M_1_fc1 = self.M_1_fc1(M_1_flatten)
|
||||
M_2_fc1 = self.M_2_fc1(M_2_flatten)
|
||||
M_1_fc2 = self.M_1_fc2(M_1_fc1)
|
||||
M_2_fc2 = self.M_2_fc2(M_2_fc1)
|
||||
M_1_softmax = self.M_1_softmax(M_1_fc2)
|
||||
M_2_softmax = self.M_2_softmax(M_2_fc2)
|
||||
return M_1_softmax, M_2_softmax
|
||||
|
||||
|
||||
class M_1_stem(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv1 = torch.nn.Conv2d(out_channels=32, in_channels=1, kernel_size=5)
|
||||
self.pool1 = torch.nn.MaxPool2d(kernel_size=2)
|
||||
self.conv2 = torch.nn.Conv2d(out_channels=64, in_channels=32, kernel_size=5)
|
||||
self.pool2 = torch.nn.MaxPool2d(kernel_size=2)
|
||||
|
||||
def forward(self, *_inputs):
|
||||
conv1 = self.conv1(_inputs[0])
|
||||
pool1 = self.pool1(conv1)
|
||||
conv2 = self.conv2(pool1)
|
||||
pool2 = self.pool2(conv2)
|
||||
return pool2
|
||||
|
||||
|
||||
class M_2_stem(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv1 = torch.nn.Conv2d(out_channels=32, in_channels=1, kernel_size=5)
|
||||
self.pool1 = torch.nn.MaxPool2d(kernel_size=2)
|
||||
self.conv2 = torch.nn.Conv2d(out_channels=64, in_channels=32, kernel_size=5)
|
||||
self.pool2 = torch.nn.MaxPool2d(kernel_size=2)
|
||||
|
||||
def forward(self, *_inputs):
|
||||
conv1 = self.conv1(_inputs[0])
|
||||
pool1 = self.pool1(conv1)
|
||||
conv2 = self.conv2(pool1)
|
||||
pool2 = self.pool2(conv2)
|
||||
return pool2


def create_evaluator(n_models=None):
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_dataset = nni.trace(MNIST)(root='data/mnist', train=True, download=False, transform=transform)
    test_dataset = nni.trace(MNIST)(root='data/mnist', train=False, download=False, transform=transform)

    multi_module = MultiModelLightningModule(
        nn.CrossEntropyLoss(),
        torchmetrics.Accuracy('multiclass', num_classes=10),
        n_models=n_models
    )

    lightning = Lightning(
        multi_module,
        MultiModelTrainer(max_epochs=1, limit_train_batches=0.25, enable_progress_bar=True),
        train_dataloaders=DataLoader(train_dataset, batch_size=100),
        val_dataloaders=DataLoader(test_dataset, batch_size=100)
    )
    return lightning
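
# A minimal usage sketch (assumes MNIST is already present under data/mnist, since
# download=False above): the evaluator trains a merged multi-model and reports one
# final metric per sub-model, which is what the CPU/GPU trainer tests below check.
#
#   evaluator = create_evaluator(n_models=2)
#   evaluator.evaluate(_model_cpu())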


def _load_mnist(n_models: int = 1):
    path = Path(__file__).parent / 'mnist_pytorch.json'
    with open(path) as f:
        mnist_model = PytorchGraphModelSpace._load(**nni.load(fp=f))
    mnist_model.evaluator = create_evaluator()
    mnist_model.status = ModelStatus.Frozen

    if n_models == 1:
        return mnist_model
    else:
        models = [mnist_model]
        for _ in range(n_models - 1):
            forked_model = mnist_model.fork()
            forked_model.status = ModelStatus.Frozen
            models.append(forked_model)
        return models
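
# `fork()` above copies the loaded graph model, so `_load_mnist(n)` returns n structurally
# identical MNIST models that can be submitted together and merged by cross-graph optimization.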


def _build_logical_with_mnist(n_models: int):
    lp = LogicalPlan(model_cls=PytorchGraphModelSpace)
    models = _load_mnist(n_models=n_models)
    for m in models:
        lp.add_model(m)
    return lp, models


@pytest.fixture(autouse=True)
def seed():
    seed_everything(42)


@pytest.fixture
def trial_command_channel():
    _default_channel = get_default_trial_command_channel()
    channel = TestHelperTrialCommandChannel()
    set_default_trial_command_channel(channel)

    nni.get_next_parameter()

    yield channel

    set_default_trial_command_channel(_default_channel)


@pytest.fixture(params=[1, 2, 4])
def cgo(request):
    remote = RemoteConfig(machine_list=[])
    remote.machine_list.append(RemoteMachineConfig(host='test', gpu_indices=list(range(request.param))))

    cgo = CrossGraphOptimization(remote_config=remote, batch_waiting_time=0)

    yield cgo

    cgo.shutdown()
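
# The fixture above fakes a remote training service whose single machine exposes 1, 2 or
# 4 GPUs; the number of visible devices is what determines how many logical models the
# CrossGraphOptimization middleware can merge into one physical model in the tests below.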


def test_multi_model_trainer_cpu(trial_command_channel):
    evaluator = create_evaluator(n_models=2)
    evaluator.evaluate(_model_cpu())

    result = trial_command_channel.final
    assert len(result) == 2

    for _ in result:
        assert _ > 0.8


@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='test requires GPU and torch+cuda')
def test_multi_model_trainer_gpu(trial_command_channel):
    evaluator = create_evaluator(n_models=2)
    evaluator.evaluate(_model_gpu())

    result = trial_command_channel.final
    assert len(result) == 2

    for _ in result:
        assert _ > 0.8


def test_add_model():
    lp, models = _build_logical_with_mnist(3)

    for node in lp.logical_graph.hidden_nodes:
        old_nodes = [m.root_graph.get_node_by_id(node.id) for m in models]

        assert any([old_nodes[0].__repr__() == Node.__repr__(x) for x in old_nodes])


def test_dedup_input(cgo):
    lp, _ = _build_logical_with_mnist(3)

    opt = DedupInputOptimizer()
    opt.convert(lp)

    phy_models = cgo._assemble(lp)

    if len(cgo.available_devices) == 4:
        assert len(list(phy_models)) == 1
    elif len(cgo.available_devices) == 2:
        assert len(list(phy_models)) == 2
    elif len(cgo.available_devices) == 1:
        assert len(list(phy_models)) == 3
    else:
        raise ValueError(f'Invalid device count: {cgo.available_devices}')

    cgo.shutdown()
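
# A reading of the assertions above: even after input deduplication each MNIST model
# still occupies one GPU, so the three logical models are packed into
# ceil(3 / n_devices) physical models -- 1 on four GPUs, 2 on two GPUs, 3 on a single GPU.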


def test_submit_models(cgo):
    import logging
    logging.getLogger('nni.nas.execution.sequential').setLevel(logging.DEBUG)

    models = _load_mnist(2)

    engine = SequentialExecutionEngine(continue_on_failure=True)

    cgo.set_engine(engine)
    cgo.submit_models(*models)

    cgo.wait_models()

    if not torch.cuda.is_available():
        for model in models:  # can't be trained without gpu.
            assert model.status == ModelStatus.Failed
        if len(cgo.available_devices) == 1:
            assert engine._model_count == 2  # 2 single
        else:
            assert engine._model_count == 3  # 1 + retry 2
    elif torch.cuda.device_count() == 1 and len(cgo.available_devices) == 1:
        # Should be the case on pipeline.
        assert engine._model_count == 2  # No merge at all.
        for model in models:
            assert model.status == ModelStatus.Trained
            assert model.metrics.final > 0.8
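
# A reading of the counts asserted above: with more than one visible device CGO first merges
# the two models into one physical model; when that merged trial fails (no CUDA), the two
# originals are retried separately, so one merged attempt plus two retries gives
# _model_count == 3. With only one visible device nothing is merged and the two models run
# as-is (_model_count == 2); with exactly one real GPU they both train to above 0.8 accuracy.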

@@ -1,366 +0,0 @@
import os
import threading
import unittest
import time
import torch
import torch.nn as nn
from pytorch_lightning.utilities.seed import seed_everything

from pathlib import Path

import nni
from nni.experiment.config import RemoteConfig, RemoteMachineConfig
from nni.runtime.tuner_command_channel import legacy as protocol
import json

try:
    from nni.common.device import GPUDevice
    from nni.retiarii.execution.cgo_engine import CGOExecutionEngine
    from nni.retiarii import Model
    from nni.retiarii.graph import Node

    from nni.retiarii import Model, submit_models
    from nni.retiarii.integration import RetiariiAdvisor
    from nni.retiarii.execution import set_execution_engine
    from nni.retiarii.execution.logical_optimizer.opt_dedup_input import DedupInputOptimizer
    from nni.retiarii.execution.logical_optimizer.logical_plan import LogicalPlan
    from nni.retiarii.utils import import_

    from nni.retiarii import serialize
    import nni.retiarii.evaluator.pytorch.lightning as pl
    from nni.retiarii.evaluator.pytorch.cgo.evaluator import MultiModelSupervisedLearningModule, _MultiModelSupervisedLearningModule
    import nni.retiarii.evaluator.pytorch.cgo.trainer as cgo_trainer

    import nni.retiarii.integration_api

    module_import_failed = False
except ImportError:
    module_import_failed = True

import pytest
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import Dataset
from sklearn.datasets import load_diabetes

pytestmark = pytest.mark.skip(reason='Will be rewritten.')


class _model_cpu(nn.Module):
    def __init__(self):
        super().__init__()
        self.M_1_stem = M_1_stem()
        self.M_2_stem = M_2_stem()
        self.M_1_flatten = torch.nn.Flatten()
        self.M_2_flatten = torch.nn.Flatten()
        self.M_1_fc1 = torch.nn.Linear(out_features=256, in_features=1024)
        self.M_2_fc1 = torch.nn.Linear(out_features=256, in_features=1024)
        self.M_1_fc2 = torch.nn.Linear(out_features=10, in_features=256)
        self.M_2_fc2 = torch.nn.Linear(out_features=10, in_features=256)
        self.M_1_softmax = torch.nn.Softmax()
        self.M_2_softmax = torch.nn.Softmax()

    def forward(self, *_inputs):
        M_1__inputs_to_M_2_stem = _inputs[0]
        M_1_stem = self.M_1_stem(_inputs[0])
        M_2_stem = self.M_2_stem(M_1__inputs_to_M_2_stem)
        M_1_flatten = self.M_1_flatten(M_1_stem)
        M_2_flatten = self.M_2_flatten(M_2_stem)
        M_1_fc1 = self.M_1_fc1(M_1_flatten)
        M_2_fc1 = self.M_2_fc1(M_2_flatten)
        M_1_fc2 = self.M_1_fc2(M_1_fc1)
        M_2_fc2 = self.M_2_fc2(M_2_fc1)
        M_1_softmax = self.M_1_softmax(M_1_fc2)
        M_2_softmax = self.M_2_softmax(M_2_fc2)
        return M_1_softmax, M_2_softmax


class _model_gpu(nn.Module):
    def __init__(self):
        super().__init__()
        self.M_1_stem = M_1_stem().to('cuda:0')
        self.M_2_stem = M_2_stem().to('cuda:1')
        self.M_1_flatten = torch.nn.Flatten().to('cuda:0')
        self.M_2_flatten = torch.nn.Flatten().to('cuda:1')
        self.M_1_fc1 = torch.nn.Linear(out_features=256, in_features=1024).to('cuda:0')
        self.M_2_fc1 = torch.nn.Linear(out_features=256, in_features=1024).to('cuda:1')
        self.M_1_fc2 = torch.nn.Linear(out_features=10, in_features=256).to('cuda:0')
        self.M_2_fc2 = torch.nn.Linear(out_features=10, in_features=256).to('cuda:1')
        self.M_1_softmax = torch.nn.Softmax().to('cuda:0')
        self.M_2_softmax = torch.nn.Softmax().to('cuda:1')

    def forward(self, *_inputs):
        M_1__inputs_to_M_1_stem = _inputs[0].to("cuda:0")
        M_1__inputs_to_M_2_stem = _inputs[0].to("cuda:1")
        M_1_stem = self.M_1_stem(M_1__inputs_to_M_1_stem)
        M_2_stem = self.M_2_stem(M_1__inputs_to_M_2_stem)
        M_1_flatten = self.M_1_flatten(M_1_stem)
        M_2_flatten = self.M_2_flatten(M_2_stem)
        M_1_fc1 = self.M_1_fc1(M_1_flatten)
        M_2_fc1 = self.M_2_fc1(M_2_flatten)
        M_1_fc2 = self.M_1_fc2(M_1_fc1)
        M_2_fc2 = self.M_2_fc2(M_2_fc1)
        M_1_softmax = self.M_1_softmax(M_1_fc2)
        M_2_softmax = self.M_2_softmax(M_2_fc2)
        return M_1_softmax, M_2_softmax


class M_1_stem(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(out_channels=32, in_channels=1, kernel_size=5)
        self.pool1 = torch.nn.MaxPool2d(kernel_size=2)
        self.conv2 = torch.nn.Conv2d(out_channels=64, in_channels=32, kernel_size=5)
        self.pool2 = torch.nn.MaxPool2d(kernel_size=2)

    def forward(self, *_inputs):
        conv1 = self.conv1(_inputs[0])
        pool1 = self.pool1(conv1)
        conv2 = self.conv2(pool1)
        pool2 = self.pool2(conv2)
        return pool2


class M_2_stem(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(out_channels=32, in_channels=1, kernel_size=5)
        self.pool1 = torch.nn.MaxPool2d(kernel_size=2)
        self.conv2 = torch.nn.Conv2d(out_channels=64, in_channels=32, kernel_size=5)
        self.pool2 = torch.nn.MaxPool2d(kernel_size=2)

    def forward(self, *_inputs):
        conv1 = self.conv1(_inputs[0])
        pool1 = self.pool1(conv1)
        conv2 = self.conv2(pool1)
        pool2 = self.pool2(conv2)
        return pool2


def _reset():
    # this is to not affect other tests in sdk
    nni.trial._intermediate_seq = 0
    nni.trial._params = {'foo': 'bar', 'parameter_id': 0, 'parameters': {}}
    nni.runtime.platform.test._last_metric = None
    nni.retiarii.integration_api._advisor = None
    nni.retiarii.execution.api._execution_engine = None

    seed_everything(42)


def _new_trainer():
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_dataset = serialize(MNIST, root='data/mnist', train=True, download=True, transform=transform)
    test_dataset = serialize(MNIST, root='data/mnist', train=False, download=True, transform=transform)

    multi_module = _MultiModelSupervisedLearningModule(nn.CrossEntropyLoss, {'acc': pl.AccuracyWithLogits})

    lightning = pl.Lightning(multi_module, cgo_trainer.Trainer(use_cgo=True,
                                                               max_epochs=1,
                                                               limit_train_batches=0.25,
                                                               enable_progress_bar=False),
                             train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
                             val_dataloaders=pl.DataLoader(test_dataset, batch_size=100))
    return lightning


def _load_mnist(n_models: int = 1):
    path = Path('ut/nas/mnist_pytorch.json')
    with open(path) as f:
        mnist_model = Model._load(nni.load(fp=f))
        mnist_model.evaluator = _new_trainer()

    if n_models == 1:
        return mnist_model
    else:
        models = [mnist_model]
        for i in range(n_models - 1):
            forked_model = mnist_model.fork()
            forked_model.evaluator = _new_trainer()
            models.append(forked_model)
        return models


def _get_final_result():
    result = nni.load(nni.runtime.platform.test._last_metric)['value']
    if isinstance(result, list):
        return [float(_) for _ in result]
    else:
        if isinstance(result, str) and '[' in result:
            return nni.load(result)
        return [float(result)]


class CGOEngineTest(unittest.TestCase):
    def setUp(self):
        if module_import_failed:
            self.skipTest('test skip due to failed import of nni.retiarii.evaluator.pytorch.lightning')

    def test_multi_model_trainer_cpu(self):
        _reset()
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        train_dataset = serialize(MNIST, root='data/mnist', train=True, download=True, transform=transform)
        test_dataset = serialize(MNIST, root='data/mnist', train=False, download=True, transform=transform)

        multi_module = _MultiModelSupervisedLearningModule(nn.CrossEntropyLoss, {'acc': pl.AccuracyWithLogits}, n_models=2)

        lightning = pl.Lightning(multi_module, cgo_trainer.Trainer(use_cgo=True,
                                                                    max_epochs=1,
                                                                    limit_train_batches=0.25),
                                 train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
                                 val_dataloaders=pl.DataLoader(test_dataset, batch_size=100))

        lightning._execute(_model_cpu)

        result = _get_final_result()
        assert len(result) == 2

        for _ in result:
            assert _ > 0.8

    def test_multi_model_trainer_gpu(self):
        _reset()
        if not (torch.cuda.is_available() and torch.cuda.device_count() >= 2):
            pytest.skip('test requires GPU and torch+cuda')
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        train_dataset = serialize(MNIST, root='data/mnist', train=True, download=True, transform=transform)
        test_dataset = serialize(MNIST, root='data/mnist', train=False, download=True, transform=transform)

        multi_module = _MultiModelSupervisedLearningModule(nn.CrossEntropyLoss, {'acc': pl.AccuracyWithLogits}, n_models=2)

        lightning = pl.Lightning(multi_module, cgo_trainer.Trainer(use_cgo=True,
                                                                    max_epochs=1,
                                                                    limit_train_batches=0.25),
                                 train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
                                 val_dataloaders=pl.DataLoader(test_dataset, batch_size=100))

        lightning._execute(_model_gpu)

        result = _get_final_result()
        assert len(result) == 2

        for _ in result:
            assert _ > 0.8

    def _build_logical_with_mnist(self, n_models: int):
        lp = LogicalPlan()
        models = _load_mnist(n_models=n_models)
        for m in models:
            lp.add_model(m)
        return lp, models

    def test_add_model(self):
        _reset()

        lp, models = self._build_logical_with_mnist(3)

        for node in lp.logical_graph.hidden_nodes:
            old_nodes = [m.root_graph.get_node_by_id(node.id) for m in models]

            self.assertTrue(any([old_nodes[0].__repr__() == Node.__repr__(x) for x in old_nodes]))

    def test_dedup_input_four_devices(self):
        _reset()

        lp, models = self._build_logical_with_mnist(3)

        opt = DedupInputOptimizer()
        opt.convert(lp)

        advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
        advisor._channel = protocol.LegacyCommandChannel()
        advisor.default_worker.start()
        advisor.assessor_worker.start()

        remote = RemoteConfig(machine_list=[])
        remote.machine_list.append(RemoteMachineConfig(host='test', gpu_indices=[0, 1, 2, 3]))
        cgo = CGOExecutionEngine(training_service=remote, batch_waiting_time=0)

        phy_models = cgo._assemble(lp)
        self.assertTrue(len(phy_models) == 1)
        advisor.stopping = True
        advisor.default_worker.join()
        advisor.assessor_worker.join()
        cgo.join()

    def test_dedup_input_two_devices(self):
        _reset()

        lp, models = self._build_logical_with_mnist(3)

        opt = DedupInputOptimizer()
        opt.convert(lp)

        advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
        advisor._channel = protocol.LegacyCommandChannel()
        advisor.default_worker.start()
        advisor.assessor_worker.start()

        remote = RemoteConfig(machine_list=[])
        remote.machine_list.append(RemoteMachineConfig(host='test', gpu_indices=[0, 1]))
        cgo = CGOExecutionEngine(training_service=remote, batch_waiting_time=0)

        phy_models = cgo._assemble(lp)
        self.assertTrue(len(phy_models) == 2)
        advisor.stopping = True
        advisor.default_worker.join()
        advisor.assessor_worker.join()
        cgo.join()

    def test_submit_models(self):
        _reset()
        os.makedirs('generated', exist_ok=True)
        import nni.runtime.platform.test as tt
        protocol._set_out_file(open('generated/debug_protocol_out_file.py', 'wb'))
        protocol._set_in_file(open('generated/debug_protocol_out_file.py', 'rb'))

        models = _load_mnist(2)

        advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
        advisor._channel = protocol.LegacyCommandChannel()
        advisor.default_worker.start()
        advisor.assessor_worker.start()
        # this is because RetiariiAdvisor only works after `_advisor_initialized` becomes True.
        # normally it becomes true when `handle_request_trial_jobs` is invoked
        advisor._advisor_initialized = True

        remote = RemoteConfig(machine_list=[])
        remote.machine_list.append(RemoteMachineConfig(host='test', gpu_indices=[0, 1, 2, 3]))
        cgo_engine = CGOExecutionEngine(training_service=remote, batch_waiting_time=0)
        set_execution_engine(cgo_engine)
        submit_models(*models)
        time.sleep(3)

        if torch.cuda.is_available() and torch.cuda.device_count() >= 2:
            cmd, data = protocol.receive()
            params = nni.load(data)

            tt.init_params(params)

            trial_thread = threading.Thread(target=CGOExecutionEngine.trial_execute_graph)
            trial_thread.start()
            last_metric = None
            while True:
                time.sleep(1)
                if tt._last_metric:
                    metric = tt.get_last_metric()
                    if metric == last_metric:
                        continue
                    if 'value' in metric:
                        metric['value'] = json.dumps(metric['value'])
                    advisor.handle_report_metric_data(metric)
                    last_metric = metric
                if not trial_thread.is_alive():
                    trial_thread.join()
                    break

            trial_thread.join()

        advisor.stopping = True
        advisor.default_worker.join()
        advisor.assessor_worker.join()
        cgo_engine.join()


if __name__ == '__main__':
    unittest.main()