Refactor to add new Finalizers class

Shital Shah 2020-05-06 00:58:28 -07:00
Parent 516261da7f
Commit 8ff553fd36
6 changed files with 141 additions and 108 deletions
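In short, this commit moves finalization out of Model.finalize() / Cell.finalize() / DagEdge.finalize() and into a new, standalone Finalizers class (archai/nas/finalizers.py) whose steps can be overridden per NAS algorithm; the members those steps need (cells, dag, s0_op, post_op, input_ids, ...) become public. A minimal sketch of the resulting call-site change, using only names that appear in the diff below:

    # given a search Model instance `model` (archai.nas.model.Model)

    # before this commit: finalization lived on the model itself
    #   finalized_desc = model.finalize(restore_device=False)

    # after this commit: a separate, overridable Finalizers object walks the model
    from archai.nas.finalizers import Finalizers
    finalizers = Finalizers()
    finalized_desc = finalizers.finalize_model(model, restore_device=False)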

View file

@@ -1,11 +1,8 @@
from copy import deepcopy
from typing import Callable, Iterable, List, Optional, Tuple
from abc import ABC, abstractmethod
from overrides import overrides, EnforceOverrides
import torch
from torch import nn, tensor
from overrides import overrides, EnforceOverrides
from ..common.common import logger
from .dag_edge import DagEdge
@@ -18,16 +15,17 @@ class Cell(nn.Module, ABC, EnforceOverrides):
alphas_cell:Optional['Cell']):
super().__init__()
+# some of these members are public as finalizer needs access
self.shared_alphas = alphas_cell is not None
self.desc = desc
-self._s0_op = Op.create(desc.s0_op, affine=affine)
-self._s1_op = Op.create(desc.s1_op, affine=affine)
+self.s0_op = Op.create(desc.s0_op, affine=affine)
+self.s1_op = Op.create(desc.s1_op, affine=affine)
-self._dag = Cell._create_dag(desc.nodes(),
+self.dag = Cell._create_dag(desc.nodes(),
affine=affine, droppath=droppath,
alphas_cell=alphas_cell)
-self._post_op = Op.create(desc.post_op, affine=affine)
+self.post_op = Op.create(desc.post_op, affine=affine)
@staticmethod
def _create_dag(nodes_desc:List[NodeDesc],
@@ -41,34 +39,33 @@ class Cell(nn.Module, ABC, EnforceOverrides):
for j, edge_desc in enumerate(node_desc.edges):
edges.append(DagEdge(edge_desc,
affine=affine, droppath=droppath,
-alphas_edge=alphas_cell._dag[i][j] if alphas_cell else None))
+alphas_edge=alphas_cell.dag[i][j] if alphas_cell else None))
return dag
def alphas(self)->Iterable[nn.Parameter]:
-for node in self._dag:
+for node in self.dag:
for edge in node:
for alpha in edge.alphas():
yield alpha
def weights(self)->Iterable[nn.Parameter]:
-for node in self._dag:
+for node in self.dag:
for edge in node:
for p in edge.weights():
yield p
def ops(self)->Iterable[Op]:
-for node in self._dag:
+for node in self.dag:
for edge in node:
yield edge.op()
@overrides
def forward(self, s0, s1):
-s0 = self._s0_op(s0)
-s1 = self._s1_op(s1)
+s0 = self.s0_op(s0)
+s1 = self.s1_op(s1)
states = [s0, s1]
-for node in self._dag:
+for node in self.dag:
# TODO: we should probably do average here otherwise output will
# blow up as number of primitives grows
# TODO: Current assumption is that each edge has k channel
@@ -83,33 +80,5 @@ class Cell(nn.Module, ABC, EnforceOverrides):
# TODO: Below assumes same shape except for channels but this won't
# happen for max pool etc shapes? Also, remove hard coded 2.
-return self._post_op(states)
+return self.post_op(states)
-def finalize(self)->CellDesc:
-nodes_desc:List[NodeDesc] = []
-for node in self._dag:
-edge_descs, edge_desc_ranks = [], []
-for edge in node:
-edge_desc, rank = edge.finalize()
-if rank is None:
-edge_descs.append(edge_desc) # required edge
-else: # optional edge
-edge_desc_ranks.append((edge_desc, rank))
-if len(edge_desc_ranks) > self.desc.max_final_edges:
-edge_desc_ranks.sort(key=lambda d:d[1], reverse=True)
-edge_desc_ranks = edge_desc_ranks[:self.desc.max_final_edges]
-edge_descs.extend((edr[0] for edr in edge_desc_ranks))
-nodes_desc.append(NodeDesc(edge_descs))
-finalized = CellDesc(
-cell_type=self.desc.cell_type,
-id = self.desc.id,
-nodes = nodes_desc,
-s0_op=self._s0_op.finalize()[0],
-s1_op=self._s1_op.finalize()[0],
-alphas_from = self.desc.alphas_from,
-max_final_edges=self.desc.max_final_edges,
-node_ch_out=self.desc.node_ch_out,
-post_op=self._post_op.finalize()[0]
-)
-return finalized

View file

@@ -5,6 +5,8 @@ from overrides import EnforceOverrides
from .model_desc import ModelDesc, NodeDesc
class CellBuilder(ABC, EnforceOverrides):
"""This is interface class for different NAS algorithms to implement"""
+def register_ops(self)->None:
+pass

View file

@@ -16,21 +16,17 @@ class DagEdge(nn.Module):
if droppath and self._op.can_drop_path():
assert self.training
self._wrapped = nn.Sequential(self._op, DropPath_())
-self._input_ids = desc.input_ids
+self.input_ids = desc.input_ids
+self.desc = desc
@overrides
def forward(self, inputs:List[torch.Tensor]):
-if len(self._input_ids)==1:
-return self._wrapped(inputs[self._input_ids[0]])
-elif len(self._input_ids) == len(inputs): # for perf
+if len(self.input_ids)==1:
+return self._wrapped(inputs[self.input_ids[0]])
+elif len(self.input_ids) == len(inputs): # for perf
return self._wrapped(inputs)
else:
-return self._wrapped([inputs[i] for i in self._input_ids])
-def finalize(self)->Tuple[EdgeDesc, Optional[float]]:
-op_desc, rank = self._op.finalize()
-return (EdgeDesc(op_desc, self._input_ids), rank)
+return self._wrapped([inputs[i] for i in self.input_ids])
def alphas(self)->Iterable[nn.Parameter]:
for alpha in self._op.alphas():

archai/nas/finalizers.py (new file, +90 lines)
View file

@@ -0,0 +1,90 @@
from typing import List, Tuple, Optional, Iterator
from overrides import EnforceOverrides
from torch import nn
from archai.nas.model import Model
from archai.nas.cell import Cell
from archai.nas.model_desc import CellDesc, ModelDesc, NodeDesc, EdgeDesc
class Finalizers(EnforceOverrides):
"""Provides base algorithms for finalizing model, cell and edge which can be overriden"""
def finalize_model(self, model:Model, to_cpu=True, restore_device=True)->ModelDesc:
# move the model to CPU before finalizing because each op will serialize
# its parameters and we don't want copies of them hanging around on the GPU
original = model.device_type()
if to_cpu:
model.cpu()
# finalize will create copy of state and this can overflow GPU RAM
assert model.device_type() == 'cpu'
cell_descs = self.finalize_cells(model)
if restore_device:
model.to(original, non_blocking=True)
return ModelDesc(stem0_op=model.stem0_op.finalize()[0],
stem1_op=model.stem1_op.finalize()[0],
pool_op=model.pool_op.finalize()[0],
ds_ch=model.desc.ds_ch,
n_classes=model.desc.n_classes,
cell_descs=cell_descs,
aux_tower_descs=model.desc.aux_tower_descs,
logits_op=model.logits_op.finalize()[0],
params=model.desc.params)
def finalize_cells(self, model:Model)->List[CellDesc]:
return [self.finalize_cell(cell) for cell in model.cells]
def finalize_cell(self, cell:Cell)->CellDesc:
# finalize each node first; we will need to recreate each node desc with its final version
node_descs:List[NodeDesc] = []
for node in cell.dag:
node_desc = self.finalize_node(node, cell.desc.max_final_edges)
node_descs.append(node_desc)
finalized = CellDesc(
cell_type=cell.desc.cell_type,
id = cell.desc.id,
nodes = node_descs,
s0_op=cell.s0_op.finalize()[0],
s1_op=cell.s1_op.finalize()[0],
alphas_from = cell.desc.alphas_from,
max_final_edges=cell.desc.max_final_edges,
node_ch_out=cell.desc.node_ch_out,
post_op=cell.post_op.finalize()[0]
)
return finalized
def finalize_node(self, node:nn.ModuleList, max_final_edges:int)->NodeDesc:
# get edge ranks; edges whose rank is None are deemed required
pre_selected, edge_desc_ranks = self.get_edge_ranks(node)
ranked_selected = self.select_edges(edge_desc_ranks, max_final_edges)
selected_edges = pre_selected + ranked_selected
return NodeDesc(selected_edges)
def select_edges(self, edge_desc_ranks:List[Tuple[EdgeDesc, float]],
max_final_edges:int)->List[EdgeDesc]:
if len(edge_desc_ranks) > max_final_edges:
# sort by rank (descending) and keep the top ranked edges
edge_desc_ranks.sort(key=lambda d:d[1], reverse=True)
edge_desc_ranks = edge_desc_ranks[:max_final_edges]
return [edr[0] for edr in edge_desc_ranks]
def get_edge_ranks(self, node:nn.ModuleList)\
->Tuple[List[EdgeDesc], List[Tuple[EdgeDesc, float]]]:
selected_edges, edge_desc_ranks = [], []
for edge in node:
edge_desc, rank = self.finalize_edge(edge)
# if rank is None then the edge is required
if rank is None:
selected_edges.append(edge_desc) # required edge
else: # optional edge
edge_desc_ranks.append((edge_desc, rank))
return selected_edges, edge_desc_ranks
def finalize_edge(self, edge)->Tuple[EdgeDesc, Optional[float]]:
op_desc, rank = edge._op.finalize()
return (EdgeDesc(op_desc, edge.input_ids), rank)
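Because Finalizers extends EnforceOverrides, a NAS algorithm can now swap out individual steps of finalization instead of reimplementing the whole routine. A hypothetical sketch (this subclass is illustrative only and not part of this commit) that overrides just the edge-selection step to keep the lowest-ranked edges rather than the top-ranked ones:

    from typing import List, Tuple
    from overrides import overrides
    from archai.nas.finalizers import Finalizers
    from archai.nas.model_desc import EdgeDesc

    class BottomRankFinalizers(Finalizers):  # hypothetical example subclass
        @overrides
        def select_edges(self, edge_desc_ranks:List[Tuple[EdgeDesc, float]],
                         max_final_edges:int)->List[EdgeDesc]:
            if len(edge_desc_ranks) > max_final_edges:
                # ascending sort keeps the lowest-ranked edges instead of the highest
                edge_desc_ranks.sort(key=lambda d: d[1])
                edge_desc_ranks = edge_desc_ranks[:max_final_edges]
            return [edr[0] for edr in edge_desc_ranks]

All other steps (finalize_model, finalize_cell, finalize_node, get_edge_ranks) are inherited unchanged.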

View file

@@ -20,11 +20,12 @@ class Model(nn.Module):
def __init__(self, model_desc:ModelDesc, droppath:bool, affine:bool):
super().__init__()
+# some of these fields are public as finalizer needs access to them
self.desc = model_desc
-self._stem0_op = Op.create(model_desc.stem0_op, affine=affine)
-self._stem1_op = Op.create(model_desc.stem1_op, affine=affine)
+self.stem0_op = Op.create(model_desc.stem0_op, affine=affine)
+self.stem1_op = Op.create(model_desc.stem1_op, affine=affine)
-self._cells = nn.ModuleList()
+self.cells = nn.ModuleList()
self._aux_towers = nn.ModuleList()
for i, (cell_desc, aux_tower_desc) in \
@@ -32,32 +33,30 @@ class Model(nn.Module):
self._build_cell(cell_desc, aux_tower_desc, droppath, affine)
# adaptive pooling output size to 1x1
-self._pool_op = Op.create(model_desc.pool_op, affine=affine)
+self.pool_op = Op.create(model_desc.pool_op, affine=affine)
# since ch_p records last cell's output channels
# it indicates the input channel number
-self._logits_op = Op.create(model_desc.logits_op, affine=affine)
+self.logits_op = Op.create(model_desc.logits_op, affine=affine)
-# for i,cell in enumerate(self._cells):
+# for i,cell in enumerate(self.cells):
# print(i, ml_utils.param_size(cell))
logger.info({'model_summary': self.summary()})
def _build_cell(self, cell_desc:CellDesc,
aux_tower_desc:Optional[AuxTowerDesc],
droppath:bool, affine:bool)->None:
alphas_cell = None if cell_desc.alphas_from==cell_desc.id \
-else self._cells[cell_desc.alphas_from]
+else self.cells[cell_desc.alphas_from]
cell = Cell(cell_desc, affine=affine, droppath=droppath,
alphas_cell=alphas_cell)
-self._cells.append(cell)
+self.cells.append(cell)
self._aux_towers.append(AuxTower(aux_tower_desc) \
if aux_tower_desc else None)
def summary(self)->dict:
return {
-'cell_count': len(self._cells),
-#'cell_params': [ml_utils.param_size(c) for c in self._cells]
+'cell_count': len(self.cells),
+#'cell_params': [ml_utils.param_size(c) for c in self.cells]
'params': ml_utils.param_size(self),
'alphas_p': len(list(a for a in self.alphas())),
'alphas': np.sum(a.numel() for a in self.alphas()),
@@ -65,31 +64,31 @@ class Model(nn.Module):
}
def alphas(self)->Iterable[nn.Parameter]:
-for cell in self._cells:
+for cell in self.cells:
if not cell.shared_alphas:
for alpha in cell.alphas():
yield alpha
def weights(self)->Iterable[nn.Parameter]:
-for cell in self._cells:
+for cell in self.cells:
for w in cell.weights():
yield w
def ops(self)->Iterable[Op]:
-for cell in self._cells:
+for cell in self.cells:
for op in cell.ops():
yield op
@overrides
def forward(self, x)->Tuple[Tensor, Optional[Tensor]]:
#print(torch.cuda.memory_allocated()/1.0e6)
-s0 = self._stem0_op(x)
+s0 = self.stem0_op(x)
#print(torch.cuda.memory_allocated()/1.0e6)
-s1 = self._stem1_op(x)
+s1 = self.stem1_op(x)
#print(-1, s0.shape, s1.shape, torch.cuda.memory_allocated()/1.0e6)
logits_aux = None
-for ci, (cell, aux_tower) in enumerate(zip(self._cells, self._aux_towers)):
+for ci, (cell, aux_tower) in enumerate(zip(self.cells, self._aux_towers)):
#print(s0.shape, s1.shape, end='')
s0, s1 = s1, cell.forward(s0, s1)
#print(ci, s0.shape, s1.shape, torch.cuda.memory_allocated()/1.0e6)
@@ -100,8 +99,8 @@ class Model(nn.Module):
#print(ci, 'aux', logits_aux.shape)
# s1 is now the last cell's output
-out = self._pool_op(s1)
-logits = self._logits_op(out) # flatten
+out = self.pool_op(s1)
+logits = self.logits_op(out) # flatten
#print(-1, 'out', out.shape)
#print(-1, 'logits', logits.shape)
@@ -111,31 +110,6 @@ class Model(nn.Module):
def device_type(self)->str:
return next(self.parameters()).device.type
-def finalize(self, to_cpu=True, restore_device=True)->ModelDesc:
-# move model to CPU before finalize because each op will serialize
-# its parameters and we don't want copy of these parameters lying on GPU
-original = self.device_type()
-if to_cpu:
-self.cpu()
-# finalize will create copy of state and this can overflow GPU RAM
-assert self.device_type() == 'cpu'
-cell_descs = [cell.finalize() for cell in self._cells]
-if restore_device:
-self.to(original, non_blocking=True)
-return ModelDesc(stem0_op=self._stem0_op.finalize()[0],
-stem1_op=self._stem1_op.finalize()[0],
-pool_op=self._pool_op.finalize()[0],
-ds_ch=self.desc.ds_ch,
-n_classes=self.desc.n_classes,
-cell_descs=cell_descs,
-aux_tower_descs=self.desc.aux_tower_descs,
-logits_op=self._logits_op.finalize()[0],
-params=self.desc.params)
def drop_path_prob(self, p:float):
""" Set drop path probability
This will be called externally so any DropPath_ modules get
@@ -163,9 +137,9 @@ class AuxTower(nn.Module):
nn.BatchNorm2d(768),
nn.ReLU(inplace=True),
)
-self._logits_op = nn.Linear(768, aux_tower_desc.n_classes)
+self.logits_op = nn.Linear(768, aux_tower_desc.n_classes)
def forward(self, x:torch.Tensor):
x = self.features(x)
-x = self._logits_op(x.view(x.size(0), -1))
+x = self.logits_op(x.view(x.size(0), -1))
return x

View file

@@ -2,6 +2,7 @@ from typing import Iterator, Mapping, Type, Optional, Tuple, List
import math
import copy
import random
+import os
import torch
import tensorwatch as tw
@@ -11,16 +12,16 @@ import yaml
from archai.common.common import logger
from archai.common.checkpoint import CheckPoint
from archai.common.config import Config
-from .cell_builder import CellBuilder
-from .arch_trainer import TArchTrainer
-from . import nas_utils
-from .model_desc import CellType, ModelDesc
+from archai.nas.cell_builder import CellBuilder
+from archai.nas.arch_trainer import TArchTrainer
+from archai.nas import nas_utils
+from archai.nas.model_desc import CellType, ModelDesc
from archai.common.trainer import Trainer
from archai.datasets import data
-from .model import Model
+from archai.nas.model import Model
from archai.common.metrics import EpochMetrics, Metrics
from archai.common import utils
-import os
+from archai.nas.finalizers import Finalizers
class MetricsStats:
"""Holds model statistics and training metrics for given description"""
@@ -336,7 +337,8 @@ class Search:
@staticmethod
def _create_metrics_stats(model:Model, train_metrics:Metrics)->MetricsStats:
-finalized = model.finalize(restore_device=False)
+finalizers = Finalizers()
+finalized = finalizers.finalize_model(model, restore_device=False)
# model stats is doing some hooks so do it last
model_stats = tw.ModelStats(model, [1,3,32,32],# TODO: remove this hard coding
clone_model=True)
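Note that _create_metrics_stats instantiates the base Finalizers directly, so search gets the default selection behavior: edges whose rank is None are always kept, and the remaining edges compete for max_final_edges slots by descending rank. A small illustrative snippet of that default (the edge names and ranks here are made up, and plain tuples stand in for (EdgeDesc, float) pairs since select_edges only inspects the rank element):

    from archai.nas.finalizers import Finalizers

    finalizers = Finalizers()
    ranked = [('sep_conv_3x3', 0.42), ('skip_connect', 0.91), ('max_pool_3x3', 0.17)]
    print(finalizers.select_edges(ranked, max_final_edges=2))
    # keeps the two highest ranked entries: ['skip_connect', 'sep_conv_3x3']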