Mirror of https://github.com/microsoft/archai.git
Refactor to add new Finalizers class

Parent: 516261da7f
Commit: 8ff553fd36
@@ -1,11 +1,8 @@
-from copy import deepcopy
 from typing import Callable, Iterable, List, Optional, Tuple
 from abc import ABC, abstractmethod
 
-from overrides import overrides, EnforceOverrides
-
-import torch
 from torch import nn, tensor
+from overrides import overrides, EnforceOverrides
 
 from ..common.common import logger
 from .dag_edge import DagEdge
@@ -18,16 +15,17 @@ class Cell(nn.Module, ABC, EnforceOverrides):
                  alphas_cell:Optional['Cell']):
         super().__init__()
 
+        # some of these members are public as finalizer needs access
         self.shared_alphas = alphas_cell is not None
         self.desc = desc
-        self._s0_op = Op.create(desc.s0_op, affine=affine)
-        self._s1_op = Op.create(desc.s1_op, affine=affine)
+        self.s0_op = Op.create(desc.s0_op, affine=affine)
+        self.s1_op = Op.create(desc.s1_op, affine=affine)
 
-        self._dag = Cell._create_dag(desc.nodes(),
+        self.dag = Cell._create_dag(desc.nodes(),
             affine=affine, droppath=droppath,
             alphas_cell=alphas_cell)
 
-        self._post_op = Op.create(desc.post_op, affine=affine)
+        self.post_op = Op.create(desc.post_op, affine=affine)
 
     @staticmethod
     def _create_dag(nodes_desc:List[NodeDesc],
@@ -41,34 +39,33 @@ class Cell(nn.Module, ABC, EnforceOverrides):
             for j, edge_desc in enumerate(node_desc.edges):
                 edges.append(DagEdge(edge_desc,
                     affine=affine, droppath=droppath,
-                    alphas_edge=alphas_cell._dag[i][j] if alphas_cell else None))
+                    alphas_edge=alphas_cell.dag[i][j] if alphas_cell else None))
         return dag
 
     def alphas(self)->Iterable[nn.Parameter]:
-        for node in self._dag:
+        for node in self.dag:
             for edge in node:
                 for alpha in edge.alphas():
                     yield alpha
 
     def weights(self)->Iterable[nn.Parameter]:
-        for node in self._dag:
+        for node in self.dag:
             for edge in node:
                 for p in edge.weights():
                     yield p
 
     def ops(self)->Iterable[Op]:
-        for node in self._dag:
+        for node in self.dag:
             for edge in node:
                 yield edge.op()
 
-
     @overrides
     def forward(self, s0, s1):
-        s0 = self._s0_op(s0)
-        s1 = self._s1_op(s1)
+        s0 = self.s0_op(s0)
+        s1 = self.s1_op(s1)
 
         states = [s0, s1]
-        for node in self._dag:
+        for node in self.dag:
             # TODO: we should probably do average here otherwise output will
             # blow up as number of primitives grows
             # TODO: Current assumption is that each edge has k channel
@@ -83,33 +80,5 @@ class Cell(nn.Module, ABC, EnforceOverrides):
 
         # TODO: Below assumes same shape except for channels but this won't
         # happen for max pool etc shapes? Also, remove hard coded 2.
-        return self._post_op(states)
+        return self.post_op(states)
 
-    def finalize(self)->CellDesc:
-        nodes_desc:List[NodeDesc] = []
-        for node in self._dag:
-            edge_descs, edge_desc_ranks = [], []
-            for edge in node:
-                edge_desc, rank = edge.finalize()
-                if rank is None:
-                    edge_descs.append(edge_desc) # required edge
-                else: # optional edge
-                    edge_desc_ranks.append((edge_desc, rank))
-            if len(edge_desc_ranks) > self.desc.max_final_edges:
-                edge_desc_ranks.sort(key=lambda d:d[1], reverse=True)
-                edge_desc_ranks = edge_desc_ranks[:self.desc.max_final_edges]
-            edge_descs.extend((edr[0] for edr in edge_desc_ranks))
-            nodes_desc.append(NodeDesc(edge_descs))
-
-        finalized = CellDesc(
-            cell_type=self.desc.cell_type,
-            id = self.desc.id,
-            nodes = nodes_desc,
-            s0_op=self._s0_op.finalize()[0],
-            s1_op=self._s1_op.finalize()[0],
-            alphas_from = self.desc.alphas_from,
-            max_final_edges=self.desc.max_final_edges,
-            node_ch_out=self.desc.node_ch_out,
-            post_op=self._post_op.finalize()[0]
-        )
-        return finalized
@@ -5,6 +5,8 @@ from overrides import EnforceOverrides
 from .model_desc import ModelDesc, NodeDesc
 
 class CellBuilder(ABC, EnforceOverrides):
     """This is interface class for different NAS algorithms to implement"""
 
+    def register_ops(self)->None:
+        pass
 
@@ -16,21 +16,17 @@ class DagEdge(nn.Module):
         if droppath and self._op.can_drop_path():
             assert self.training
             self._wrapped = nn.Sequential(self._op, DropPath_())
-        self._input_ids = desc.input_ids
+        self.input_ids = desc.input_ids
         self.desc = desc
 
     @overrides
     def forward(self, inputs:List[torch.Tensor]):
-        if len(self._input_ids)==1:
-            return self._wrapped(inputs[self._input_ids[0]])
-        elif len(self._input_ids) == len(inputs): # for perf
+        if len(self.input_ids)==1:
+            return self._wrapped(inputs[self.input_ids[0]])
+        elif len(self.input_ids) == len(inputs): # for perf
             return self._wrapped(inputs)
         else:
-            return self._wrapped([inputs[i] for i in self._input_ids])
-
-    def finalize(self)->Tuple[EdgeDesc, Optional[float]]:
-        op_desc, rank = self._op.finalize()
-        return (EdgeDesc(op_desc, self._input_ids), rank)
+            return self._wrapped([inputs[i] for i in self.input_ids])
 
     def alphas(self)->Iterable[nn.Parameter]:
         for alpha in self._op.alphas():
@@ -0,0 +1,90 @@
+from typing import List, Tuple, Optional, Iterator
+from overrides import EnforceOverrides
+
+from torch import nn
+
+from archai.nas.model import Model
+from archai.nas.cell import Cell
+from archai.nas.model_desc import CellDesc, ModelDesc, NodeDesc, EdgeDesc
+
+class Finalizers(EnforceOverrides):
+    """Provides base algorithms for finalizing model, cell and edge which can be overriden"""
+
+    def finalize_model(self, model:Model, to_cpu=True, restore_device=True)->ModelDesc:
+        # move model to CPU before finalize because each op will serialize
+        # its parameters and we don't want copy of these parameters hanging on GPU
+        original = model.device_type()
+        if to_cpu:
+            model.cpu()
+
+        # finalize will create copy of state and this can overflow GPU RAM
+        assert model.device_type() == 'cpu'
+
+        cell_descs = self.finalize_cells(model)
+
+        if restore_device:
+            model.to(original, non_blocking=True)
+
+        return ModelDesc(stem0_op=model.stem0_op.finalize()[0],
+                         stem1_op=model.stem1_op.finalize()[0],
+                         pool_op=model.pool_op.finalize()[0],
+                         ds_ch=model.desc.ds_ch,
+                         n_classes=model.desc.n_classes,
+                         cell_descs=cell_descs,
+                         aux_tower_descs=model.desc.aux_tower_descs,
+                         logits_op=model.logits_op.finalize()[0],
+                         params=model.desc.params)
+
+    def finalize_cells(self, model:Model)->List[CellDesc]:
+        return [self.finalize_cell(cell) for cell in model.cells]
+
+    def finalize_cell(self, cell:Cell)->CellDesc:
+        # first finalize each node, we will need to recreate node desc with final version
+        node_descs:List[NodeDesc] = []
+        for node in cell.dag:
+            node_desc = self.finalize_node(node, cell.desc.max_final_edges)
+            node_descs.append(node_desc)
+
+        finalized = CellDesc(
+            cell_type=cell.desc.cell_type,
+            id = cell.desc.id,
+            nodes = node_descs,
+            s0_op=cell.s0_op.finalize()[0],
+            s1_op=cell.s1_op.finalize()[0],
+            alphas_from = cell.desc.alphas_from,
+            max_final_edges=cell.desc.max_final_edges,
+            node_ch_out=cell.desc.node_ch_out,
+            post_op=cell.post_op.finalize()[0]
+        )
+        return finalized
+
+    def finalize_node(self, node:nn.ModuleList, max_final_edges:int)->NodeDesc:
+        # get edge ranks, if rank is None it is deemed as required
+        pre_selected, edge_desc_ranks = self.get_edge_ranks(node)
+        ranked_selected = self.select_edges(edge_desc_ranks, max_final_edges)
+        selected_edges = pre_selected + ranked_selected
+        return NodeDesc(selected_edges)
+
+    def select_edges(self, edge_desc_ranks:List[Tuple[EdgeDesc, float]],
+                     max_final_edges:int)->List[EdgeDesc]:
+        if len(edge_desc_ranks) > max_final_edges:
+            # sort by rank and pick bottom
+            edge_desc_ranks.sort(key=lambda d:d[1], reverse=True)
+            edge_desc_ranks = edge_desc_ranks[:max_final_edges]
+        return [edr[0] for edr in edge_desc_ranks]
+
+    def get_edge_ranks(self, node:nn.ModuleList)\
+            ->Tuple[List[EdgeDesc], List[Tuple[EdgeDesc, float]]]:
+        selected_edges, edge_desc_ranks = [], []
+        for edge in node:
+            edge_desc, rank = self.finalize_edge(edge)
+            # if rank is None then it is required rank
+            if rank is None:
+                selected_edges.append(edge_desc) # required edge
+            else: # optional edge
+                edge_desc_ranks.append((edge_desc, rank))
+        return selected_edges, edge_desc_ranks
+
+    def finalize_edge(self, edge)->Tuple[EdgeDesc, Optional[float]]:
+        op_desc, rank = edge._op.finalize()
+        return (EdgeDesc(op_desc, edge.input_ids), rank)
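Note: this new class takes over the finalization logic previously spread across Model, Cell and DagEdge, and each stage (edge, node, cell, model) can be overridden by a NAS algorithm. A minimal sketch of the new call pattern, mirroring the Search._create_metrics_stats change at the end of this diff; the helper name describe_final_model is hypothetical and `model` is assumed to be a trained search Model:

    from archai.nas.finalizers import Finalizers
    from archai.nas.model import Model
    from archai.nas.model_desc import ModelDesc

    def describe_final_model(model:Model)->ModelDesc:
        # hypothetical helper: Finalizers ranks each edge via its op, keeps
        # required edges plus the top max_final_edges ranked edges per node,
        # and returns a serializable description of the finalized architecture
        finalizers = Finalizers()
        return finalizers.finalize_model(model, restore_device=False)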
@@ -20,11 +20,12 @@ class Model(nn.Module):
     def __init__(self, model_desc:ModelDesc, droppath:bool, affine:bool):
         super().__init__()
 
+        # some of these fields are public as finalizer needs access to them
         self.desc = model_desc
-        self._stem0_op = Op.create(model_desc.stem0_op, affine=affine)
-        self._stem1_op = Op.create(model_desc.stem1_op, affine=affine)
+        self.stem0_op = Op.create(model_desc.stem0_op, affine=affine)
+        self.stem1_op = Op.create(model_desc.stem1_op, affine=affine)
 
-        self._cells = nn.ModuleList()
+        self.cells = nn.ModuleList()
        self._aux_towers = nn.ModuleList()
 
         for i, (cell_desc, aux_tower_desc) in \
@@ -32,32 +33,30 @@ class Model(nn.Module):
             self._build_cell(cell_desc, aux_tower_desc, droppath, affine)
 
         # adaptive pooling output size to 1x1
-        self._pool_op = Op.create(model_desc.pool_op, affine=affine)
+        self.pool_op = Op.create(model_desc.pool_op, affine=affine)
         # since ch_p records last cell's output channels
         # it indicates the input channel number
-        self._logits_op = Op.create(model_desc.logits_op, affine=affine)
+        self.logits_op = Op.create(model_desc.logits_op, affine=affine)
 
-        # for i,cell in enumerate(self._cells):
+        # for i,cell in enumerate(self.cells):
         # print(i, ml_utils.param_size(cell))
         logger.info({'model_summary': self.summary()})
-
-
 
     def _build_cell(self, cell_desc:CellDesc,
                     aux_tower_desc:Optional[AuxTowerDesc],
                     droppath:bool, affine:bool)->None:
         alphas_cell = None if cell_desc.alphas_from==cell_desc.id \
-                      else self._cells[cell_desc.alphas_from]
+                      else self.cells[cell_desc.alphas_from]
         cell = Cell(cell_desc, affine=affine, droppath=droppath,
                     alphas_cell=alphas_cell)
-        self._cells.append(cell)
+        self.cells.append(cell)
         self._aux_towers.append(AuxTower(aux_tower_desc) \
                                 if aux_tower_desc else None)
 
     def summary(self)->dict:
         return {
-            'cell_count': len(self._cells),
-            #'cell_params': [ml_utils.param_size(c) for c in self._cells]
+            'cell_count': len(self.cells),
+            #'cell_params': [ml_utils.param_size(c) for c in self.cells]
             'params': ml_utils.param_size(self),
             'alphas_p': len(list(a for a in self.alphas())),
             'alphas': np.sum(a.numel() for a in self.alphas()),
@@ -65,31 +64,31 @@ class Model(nn.Module):
         }
 
     def alphas(self)->Iterable[nn.Parameter]:
-        for cell in self._cells:
+        for cell in self.cells:
             if not cell.shared_alphas:
                 for alpha in cell.alphas():
                     yield alpha
 
     def weights(self)->Iterable[nn.Parameter]:
-        for cell in self._cells:
+        for cell in self.cells:
             for w in cell.weights():
                 yield w
 
     def ops(self)->Iterable[Op]:
-        for cell in self._cells:
+        for cell in self.cells:
             for op in cell.ops():
                 yield op
 
     @overrides
     def forward(self, x)->Tuple[Tensor, Optional[Tensor]]:
         #print(torch.cuda.memory_allocated()/1.0e6)
-        s0 = self._stem0_op(x)
+        s0 = self.stem0_op(x)
         #print(torch.cuda.memory_allocated()/1.0e6)
-        s1 = self._stem1_op(x)
+        s1 = self.stem1_op(x)
         #print(-1, s0.shape, s1.shape, torch.cuda.memory_allocated()/1.0e6)
 
         logits_aux = None
-        for ci, (cell, aux_tower) in enumerate(zip(self._cells, self._aux_towers)):
+        for ci, (cell, aux_tower) in enumerate(zip(self.cells, self._aux_towers)):
             #print(s0.shape, s1.shape, end='')
             s0, s1 = s1, cell.forward(s0, s1)
             #print(ci, s0.shape, s1.shape, torch.cuda.memory_allocated()/1.0e6)
@@ -100,8 +99,8 @@ class Model(nn.Module):
                 #print(ci, 'aux', logits_aux.shape)
 
         # s1 is now the last cell's output
-        out = self._pool_op(s1)
-        logits = self._logits_op(out) # flatten
+        out = self.pool_op(s1)
+        logits = self.logits_op(out) # flatten
 
         #print(-1, 'out', out.shape)
         #print(-1, 'logits', logits.shape)
@@ -111,31 +110,6 @@ class Model(nn.Module):
     def device_type(self)->str:
         return next(self.parameters()).device.type
 
-    def finalize(self, to_cpu=True, restore_device=True)->ModelDesc:
-        # move model to CPU before finalize because each op will serialize
-        # its parameters and we don't want copy of these parameters lying on GPU
-        original = self.device_type()
-        if to_cpu:
-            self.cpu()
-
-        # finalize will create copy of state and this can overflow GPU RAM
-        assert self.device_type() == 'cpu'
-
-        cell_descs = [cell.finalize() for cell in self._cells]
-
-        if restore_device:
-            self.to(original, non_blocking=True)
-
-        return ModelDesc(stem0_op=self._stem0_op.finalize()[0],
-                         stem1_op=self._stem1_op.finalize()[0],
-                         pool_op=self._pool_op.finalize()[0],
-                         ds_ch=self.desc.ds_ch,
-                         n_classes=self.desc.n_classes,
-                         cell_descs=cell_descs,
-                         aux_tower_descs=self.desc.aux_tower_descs,
-                         logits_op=self._logits_op.finalize()[0],
-                         params=self.desc.params)
-
     def drop_path_prob(self, p:float):
         """ Set drop path probability
         This will be called externally so any DropPath_ modules get
@@ -163,9 +137,9 @@ class AuxTower(nn.Module):
             nn.BatchNorm2d(768),
             nn.ReLU(inplace=True),
         )
-        self._logits_op = nn.Linear(768, aux_tower_desc.n_classes)
+        self.logits_op = nn.Linear(768, aux_tower_desc.n_classes)
 
     def forward(self, x:torch.Tensor):
         x = self.features(x)
-        x = self._logits_op(x.view(x.size(0), -1))
+        x = self.logits_op(x.view(x.size(0), -1))
         return x
@@ -2,6 +2,7 @@ from typing import Iterator, Mapping, Type, Optional, Tuple, List
 import math
 import copy
 import random
+import os
 
 import torch
 import tensorwatch as tw
@@ -11,16 +12,16 @@ import yaml
 from archai.common.common import logger
 from archai.common.checkpoint import CheckPoint
 from archai.common.config import Config
-from .cell_builder import CellBuilder
-from .arch_trainer import TArchTrainer
-from . import nas_utils
-from .model_desc import CellType, ModelDesc
+from archai.nas.cell_builder import CellBuilder
+from archai.nas.arch_trainer import TArchTrainer
+from archai.nas import nas_utils
+from archai.nas.model_desc import CellType, ModelDesc
 from archai.common.trainer import Trainer
 from archai.datasets import data
-from .model import Model
+from archai.nas.model import Model
 from archai.common.metrics import EpochMetrics, Metrics
 from archai.common import utils
-import os
+from archai.nas.finalizers import Finalizers
 
 class MetricsStats:
     """Holds model statistics and training metrics for given description"""
@@ -336,7 +337,8 @@ class Search:
 
     @staticmethod
     def _create_metrics_stats(model:Model, train_metrics:Metrics)->MetricsStats:
-        finalized = model.finalize(restore_device=False)
+        finalizers = Finalizers()
+        finalized = finalizers.finalize_model(model, restore_device=False)
         # model stats is doing some hooks so do it last
         model_stats = tw.ModelStats(model, [1,3,32,32],# TODO: remove this hard coding
                                     clone_model=True)