add tool to convert checkpoints to onnx.

Chris Lovett 2023-03-31 23:50:45 -07:00
Parent b80150587d
Commit 190b43a8d8
5 changed files with 106 additions and 37 deletions

View file

@@ -0,0 +1,60 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from pathlib import Path
import torch
import os
from argparse import ArgumentParser
from archai.discrete_search.search_spaces.config import ArchConfig
from search_space.hgnet import StackedHourglass
def export(checkpoint, model, onnx_file):
state_dict = checkpoint['state_dict']
# strip the 'model.' prefix off the checkpoint keys
state_dict = {k[len('model.'):]: v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
input_shapes = [(1, 3, 256, 256)]
rand_range = (0.0, 1.0)
export_kwargs = {'opset_version': 11}
rand_min, rand_max = rand_range
sample_inputs = tuple(
[
((rand_max - rand_min) * torch.rand(*input_shape) + rand_min).type("torch.FloatTensor")
for input_shape in input_shapes
]
)
torch.onnx.export(
model,
sample_inputs,
onnx_file,
input_names=[f"input_{i}" for i in range(len(sample_inputs))],
**export_kwargs,
)
print(f'Exported {onnx_file}')
def main():
parser = ArgumentParser(
    description="Converts a checkpoint (e.g. final_model.ckpt) to ONNX, writing the .onnx model to the same folder as the checkpoint."
)
parser.add_argument('arch', type=Path, help="Path to config.json file describing the model architecture")
parser.add_argument('--checkpoint', required=True, help="Path of the checkpoint to export")
args = parser.parse_args()
checkpoint = torch.load(args.checkpoint)
# get the directory name from args.checkpoint
output_path = os.path.dirname(os.path.realpath(args.checkpoint))
base_name = os.path.splitext(os.path.basename(args.checkpoint))[0]
onnx_file = os.path.join(output_path, f'{base_name}.onnx')
arch_config = ArchConfig.from_file(args.arch)
model = StackedHourglass(arch_config, num_classes=18)
export(checkpoint, model, onnx_file)
if __name__ == '__main__':
main()
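A quick sanity check of the exported file can be done with onnxruntime (a minimal sketch, assuming onnxruntime and numpy are installed and the checkpoint was trained at 256x256 with the 18-class head used above):

import numpy as np
import onnxruntime as ort

# load the .onnx file written next to the checkpoint and feed it one random input
sess = ort.InferenceSession('final_model.onnx')
x = np.random.rand(1, 3, 256, 256).astype(np.float32)
out = sess.run(None, {'input_0': x})[0]

# expect a (1, 18, H, W) score map; H and W match the input only if the final
# upsample restores the stem stride (an assumption, not verified here)
print(out.shape)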

View file

@@ -1,3 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import itertools
from pathlib import Path
@@ -93,7 +95,7 @@ if __name__ == '__main__':
]
evaluator = RemoteAzureBenchmarkEvaluator(
input_shape=input_shape,
onnx_export_kwargs={'opset_version': 11},
**target_config
)
@@ -121,7 +123,7 @@ if __name__ == '__main__':
if not args.serial_training:
partial_tr_obj = RayParallelEvaluator(
partial_tr_obj, num_gpus=args.gpus_per_job,
max_calls=1
)
@@ -138,7 +140,7 @@ if __name__ == '__main__':
# Search algorithm
algo_config = search_config['algorithm']
algo = AVAILABLE_ALGOS[algo_config['name']](
search_space, so, dataset_provider,
output_dir=args.output_dir, seed=args.seed,
**algo_config.get('params', {}),
)

View file

@@ -1,3 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from functools import partial
from typing import Tuple, Optional, List
@@ -21,31 +23,31 @@ def hgnet_param_tree_factory(stem_strides: Tuple[int, ...] = (2, 4),
skip_block_max_ops: int = 3,
upsample_block_max_ops: int = 4):
assert num_blocks > 1, 'num_blocks must be greater than 1'
return ArchParamTree({
'stem_stride': DiscreteChoice(stem_strides),
'base_ch': DiscreteChoice(base_channels),
'hourglasses': repeat_config({
'downsample_blocks': repeat_config({
'layers': repeat_config({
'op': DiscreteChoice(op_subset)
}, repeat_times=range(1, downsample_block_max_ops + 1), share_arch=False),
'ch_expansion_factor': DiscreteChoice([1.0, 1.2, 1.5, 1.6, 2.0, 2.2]),
}, repeat_times=num_blocks),
'skip_blocks': repeat_config({
'layers': repeat_config({
'op': DiscreteChoice(op_subset)
}, repeat_times=range(0, skip_block_max_ops+1), share_arch=False),
}, repeat_times=num_blocks-1),
'upsample_blocks': repeat_config({
'layers': repeat_config({
'op': DiscreteChoice(op_subset)
}, repeat_times=range(1, upsample_block_max_ops+1), share_arch=False),
}, repeat_times=num_blocks-1),
}, repeat_times=range(1, max_num_hourglass+1), share_arch=share_hourglass_arch),
@@ -62,20 +64,20 @@ class Hourglass(nn.Module):
self.upsample = nn.UpsamplingBilinear2d(scale_factor=2)
self.chs = [self.base_channels]
# Calculates channels on each branch
for block_cfg in arch_config.pick('downsample_blocks'):
self.chs.append(
int(self.chs[-1] * block_cfg.pick('ch_expansion_factor'))
)
self.nb_blocks = len(self.chs) - 1
# Downsample blocks
self.down_blocks = nn.ModuleList()
for block_idx, block_cfg in enumerate(arch_config.pick('downsample_blocks')):
in_ch, out_ch = self.chs[block_idx], self.chs[block_idx + 1]
down_block = [
OPS[layer_cfg.pick('op')](
(in_ch if layer_idx == 0 else out_ch),
@@ -84,21 +86,21 @@ class Hourglass(nn.Module):
)
for layer_idx, layer_cfg in enumerate(block_cfg.pick('layers'))
]
self.down_blocks.append(nn.Sequential(*down_block))
# Skip blocks
self.skip_blocks = nn.ModuleList()
for block_idx, block_cfg in enumerate(arch_config.pick('skip_blocks')):
out_ch = self.chs[block_idx + 1]
skip_block = [
OPS.get(layer_cfg.pick('op'))(out_ch, out_ch)
for layer_idx, layer_cfg in enumerate(block_cfg.pick('layers'))
]
self.skip_blocks.append(nn.Sequential(*skip_block))
# Upsample blocks
self.up_blocks = nn.ModuleList()
for block_idx, block_cfg in enumerate(arch_config.pick('upsample_blocks')):
@@ -124,7 +126,7 @@ class Hourglass(nn.Module):
out = self.down_blocks[i](inp)
skip_connections[i] = self.skip_blocks[i](out)
inp = out
# Last downsample branch
out = self.down_blocks[-1](inp)
@@ -137,24 +139,24 @@ class Hourglass(nn.Module):
class StackedHourglass(nn.Module):
def __init__(self, arch_config: ArchConfig, num_classes: int, in_channels: int = 3):
super().__init__()
self.num_classes = num_classes
self.in_channels = in_channels
self.arch_config = arch_config
self.base_channels = arch_config.pick('base_ch')
# Classifier
self.classifier = nn.Conv2d(self.base_channels, num_classes, kernel_size=1)
# Stem convolution
self.stem_stride = arch_config.pick('stem_stride')
self.stem_conv = ReluConv2d(
in_channels=in_channels, out_channels=self.base_channels,
stride=self.stem_stride
)
self.final_upsample = nn.UpsamplingBilinear2d(scale_factor=self.stem_stride)
self.hgs = nn.Sequential(*[
Hourglass(hg_conf, self.base_channels)
for hg_conf in arch_config.pick('hourglasses')
@@ -166,7 +168,7 @@ class StackedHourglass(nn.Module):
])
self.classifier = Conv2dSamePadding(self.base_channels, num_classes, kernel_size=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
out = self.stem_conv(x)
out = self.hgs(out)
@@ -176,31 +178,31 @@ class StackedHourglass(nn.Module):
class HgnetSegmentationSearchSpace(ConfigSearchSpace):
def __init__(self,
num_classes: int,
img_size: Tuple[int, int],
in_channels: int = 3,
op_subset: Tuple[str, ...] = ('conv3x3', 'conv5x5', 'conv7x7'),
stem_strides: Tuple[int, ...] = (1, 2, 4),
num_blocks: int = 4,
downsample_block_max_ops: int = 4,
skip_block_max_ops: int = 2,
upsample_block_max_ops: int = 4,
post_upsample_max_ops: int = 3,
**ss_kwargs):
possible_downsample_factors = [
2**num_blocks * stem_stride for stem_stride in stem_strides
]
w, h = img_size
assert all(w % d_factor == 0 for d_factor in possible_downsample_factors), \
f'Image width must be divisible by all possible downsample factors ({2**num_blocks} * stem_stride)'
assert all(h % d_factor == 0 for d_factor in possible_downsample_factors), \
f'Image height must be divisible by all possible downsample factors ({2**num_blocks} * stem_stride)'
ss_kwargs['builder_kwargs'] = {
'op_subset': op_subset,
'stem_strides': stem_strides,

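For context on how these parameters become concrete models, a rough sampling sketch (assuming ConfigSearchSpace exposes the random_sample and save_arch methods of Archai's discrete search space API):

from search_space.hgnet import HgnetSegmentationSearchSpace

# 256x256 passes the divisibility check above for the default stem strides (1, 2, 4) and 4 blocks
ss = HgnetSegmentationSearchSpace(num_classes=18, img_size=(256, 256))
sampled = ss.random_sample()          # ArchaiModel wrapping a StackedHourglass instance
ss.save_arch(sampled, 'config.json')  # the kind of file the ONNX export tool above takes as 'arch'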
View file

@@ -1,3 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from functools import partial
from itertools import chain
@@ -23,7 +25,7 @@ class Conv2dSamePadding(nn.Conv2d):
class ReluConv2d(nn.Module):
def __init__(self, in_channels: int, out_channels: int,
kernel_size: int = 3, stride: int = 1,
bias: bool = False, **kwargs):
super().__init__()

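A hypothetical usage note for the block above; it assumes ReluConv2d applies a same-padded convolution followed by a ReLU, so a strided call should halve the spatial size:

import torch
from search_space.ops import ReluConv2d  # module path assumed; the diff does not show file names

layer = ReluConv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2)
y = layer(torch.rand(1, 3, 256, 256))
print(y.shape)  # expected torch.Size([1, 16, 128, 128]) if 'same' padding precedes the stride-2 conv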
View file

@@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from pathlib import Path
from argparse import ArgumentParser
@@ -17,8 +20,8 @@ parser.add_argument('--dataset_dir', type=Path, help='Face Synthetics dataset di
parser.add_argument('--output_dir', type=Path, help='Output directory.', required=True)
parser.add_argument('--lr', type=float, default=2e-4)
parser.add_argument('--batch_size', type=int, default=16)
-parser.add_argument('--epochs', type=int, default=30)
-parser.add_argument('--val_check_interval', type=float, default=1)
+parser.add_argument('--epochs', type=int, default=1)
+parser.add_argument('--val_check_interval', type=float, default=1000)
if __name__ == '__main__':
@@ -67,6 +70,6 @@ if __name__ == '__main__':
rand_range = (0.0, 1.0)
export_kwargs = {'opset_version': 11}
rand_min, rand_max = rand_range
sample_input = ((rand_max - rand_min) * torch.rand(*input_shape) + rand_min).type("torch.FloatTensor")
onnx_file = str(args.output_dir / 'final_model.onnx')
torch.onnx.export(model, (sample_input,), onnx_file, input_names=["input_0"], **export_kwargs)
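A numerical spot check that could follow the export (a sketch only; it assumes model sits on the CPU, returns a single tensor, and that onnxruntime and numpy are available):

import numpy as np
import onnxruntime as ort

model.eval()
with torch.no_grad():
    torch_out = model(sample_input).numpy()

# run the same input through the exported graph and compare
sess = ort.InferenceSession(onnx_file)
onnx_out = sess.run(None, {'input_0': sample_input.numpy()})[0]
print('max abs difference:', np.abs(torch_out - onnx_out).max())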