From 190b43a8d88899dab000e1c63d17f35a664ec3ff Mon Sep 17 00:00:00 2001 From: Chris Lovett Date: Fri, 31 Mar 2023 23:50:45 -0700 Subject: [PATCH] add tool to convert checkpoints to onnx. --- tasks/face_segmentation/export.py | 60 ++++++++++++++++++ tasks/face_segmentation/search.py | 8 ++- tasks/face_segmentation/search_space/hgnet.py | 62 ++++++++++--------- tasks/face_segmentation/search_space/ops.py | 4 +- tasks/face_segmentation/train.py | 9 ++- 5 files changed, 106 insertions(+), 37 deletions(-) create mode 100644 tasks/face_segmentation/export.py diff --git a/tasks/face_segmentation/export.py b/tasks/face_segmentation/export.py new file mode 100644 index 00000000..c2730d61 --- /dev/null +++ b/tasks/face_segmentation/export.py @@ -0,0 +1,60 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +from pathlib import Path +import torch +import os +from argparse import ArgumentParser +from archai.discrete_search.search_spaces.config import ArchConfig +from search_space.hgnet import StackedHourglass + + +def export(checkpoint, model, onnx_file): + state_dict = checkpoint['state_dict'] + # strip 'model.' prefix off the keys! + state_dict = dict({(k[6:], state_dict[k]) for k in state_dict}) + model.load_state_dict(state_dict) + input_shapes = [(1, 3, 256, 256)] + rand_range = (0.0, 1.0) + export_kwargs = {'opset_version': 11} + rand_min, rand_max = rand_range + sample_inputs = tuple( + [ + ((rand_max - rand_min) * torch.rand(*input_shape) + rand_min).type("torch.FloatTensor") + for input_shape in input_shapes + ] + ) + + torch.onnx.export( + model, + sample_inputs, + onnx_file, + input_names=[f"input_{i}" for i in range(len(sample_inputs))], + **export_kwargs, + ) + + print(f'Exported {onnx_file}') + + +def main(): + parser = ArgumentParser( + "Converts the final_model.ckpt to final_model.onnx, writing the onnx model to the same folder." + ) + parser.add_argument('arch', type=Path, help="Path to config.json file describing the model architecture") + parser.add_argument('--checkpoint', help="Path of the checkpoint to export") + + args = parser.parse_args() + + checkpoint = torch.load(args.checkpoint) + + # get the directory name from args.checkpoint + output_path = os.path.dirname(os.path.realpath(args.checkpoint)) + base_name = os.path.splitext(os.path.basename(args.checkpoint))[0] + onnx_file = os.path.join(output_path, f'{base_name}.onnx') + + arch_config = ArchConfig.from_file(args.arch) + model = StackedHourglass(arch_config, num_classes=18) + export(checkpoint, model, onnx_file) + + +if __name__ == '__main__': + main() diff --git a/tasks/face_segmentation/search.py b/tasks/face_segmentation/search.py index 117b4aa8..3ee0e591 100644 --- a/tasks/face_segmentation/search.py +++ b/tasks/face_segmentation/search.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. import os import itertools from pathlib import Path @@ -93,7 +95,7 @@ if __name__ == '__main__': ] evaluator = RemoteAzureBenchmarkEvaluator( - input_shape=input_shape, + input_shape=input_shape, onnx_export_kwargs={'opset_version': 11}, **target_config ) @@ -121,7 +123,7 @@ if __name__ == '__main__': if not args.serial_training: partial_tr_obj = RayParallelEvaluator( - partial_tr_obj, num_gpus=args.gpus_per_job, + partial_tr_obj, num_gpus=args.gpus_per_job, max_calls=1 ) @@ -138,7 +140,7 @@ if __name__ == '__main__': # Search algorithm algo_config = search_config['algorithm'] algo = AVAILABLE_ALGOS[algo_config['name']]( - search_space, so, dataset_provider, + search_space, so, dataset_provider, output_dir=args.output_dir, seed=args.seed, **algo_config.get('params', {}), ) diff --git a/tasks/face_segmentation/search_space/hgnet.py b/tasks/face_segmentation/search_space/hgnet.py index dae60bb7..fa5cc7f3 100644 --- a/tasks/face_segmentation/search_space/hgnet.py +++ b/tasks/face_segmentation/search_space/hgnet.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. from functools import partial from typing import Tuple, Optional, List @@ -21,31 +23,31 @@ def hgnet_param_tree_factory(stem_strides: Tuple[int, ...] = (2, 4), skip_block_max_ops: int = 3, upsample_block_max_ops: int = 4): assert num_blocks > 1, 'num_blocks must be greater than 1' - + return ArchParamTree({ - 'stem_stride': DiscreteChoice(stem_strides), + 'stem_stride': DiscreteChoice(stem_strides), 'base_ch': DiscreteChoice(base_channels), - + 'hourglasses': repeat_config({ 'downsample_blocks': repeat_config({ 'layers': repeat_config({ 'op': DiscreteChoice(op_subset) }, repeat_times=range(1, downsample_block_max_ops + 1), share_arch=False), - + 'ch_expansion_factor': DiscreteChoice([1.0, 1.2, 1.5, 1.6, 2.0, 2.2]), }, repeat_times=num_blocks), - + 'skip_blocks': repeat_config({ 'layers': repeat_config({ 'op': DiscreteChoice(op_subset) - }, repeat_times=range(0, skip_block_max_ops+1), share_arch=False), + }, repeat_times=range(0, skip_block_max_ops+1), share_arch=False), }, repeat_times=num_blocks-1), - + 'upsample_blocks': repeat_config({ 'layers': repeat_config({ 'op': DiscreteChoice(op_subset) - }, repeat_times=range(1, upsample_block_max_ops+1), share_arch=False), + }, repeat_times=range(1, upsample_block_max_ops+1), share_arch=False), }, repeat_times=num_blocks-1), }, repeat_times=range(1, max_num_hourglass+1), share_arch=share_hourglass_arch), @@ -62,20 +64,20 @@ class Hourglass(nn.Module): self.upsample = nn.UpsamplingBilinear2d(scale_factor=2) self.chs = [self.base_channels] - + # Calculates channels on each branch - for block_cfg in arch_config.pick('downsample_blocks'): + for block_cfg in arch_config.pick('downsample_blocks'): self.chs.append( int(self.chs[-1] * block_cfg.pick('ch_expansion_factor')) ) - + self.nb_blocks = len(self.chs) - 1 # Downsample blocks self.down_blocks = nn.ModuleList() for block_idx, block_cfg in enumerate(arch_config.pick('downsample_blocks')): in_ch, out_ch = self.chs[block_idx], self.chs[block_idx + 1] - + down_block = [ OPS[layer_cfg.pick('op')]( (in_ch if layer_idx == 0 else out_ch), @@ -84,21 +86,21 @@ class Hourglass(nn.Module): ) for layer_idx, layer_cfg in enumerate(block_cfg.pick('layers')) ] - + self.down_blocks.append(nn.Sequential(*down_block)) - + # Skip blocks self.skip_blocks = nn.ModuleList() for block_idx, block_cfg in enumerate(arch_config.pick('skip_blocks')): out_ch = self.chs[block_idx + 1] - + skip_block = [ OPS.get(layer_cfg.pick('op'))(out_ch, out_ch) for layer_idx, layer_cfg in enumerate(block_cfg.pick('layers')) ] - + self.skip_blocks.append(nn.Sequential(*skip_block)) - + # Upsample blocks self.up_blocks = nn.ModuleList() for block_idx, block_cfg in enumerate(arch_config.pick('upsample_blocks')): @@ -124,7 +126,7 @@ class Hourglass(nn.Module): out = self.down_blocks[i](inp) skip_connections[i] = self.skip_blocks[i](out) inp = out - + # Last downsample branch out = self.down_blocks[-1](inp) @@ -137,24 +139,24 @@ class Hourglass(nn.Module): class StackedHourglass(nn.Module): def __init__(self, arch_config: ArchConfig, num_classes: int, in_channels: int = 3): super().__init__() - + self.num_classes = num_classes self.in_channels = in_channels self.arch_config = arch_config self.base_channels = arch_config.pick('base_ch') - + # Classifier self.classifier = nn.Conv2d(self.base_channels, num_classes, kernel_size=1) # Stem convolution self.stem_stride = arch_config.pick('stem_stride') self.stem_conv = ReluConv2d( - in_channels=in_channels, out_channels=self.base_channels, + in_channels=in_channels, out_channels=self.base_channels, stride=self.stem_stride ) - + self.final_upsample = nn.UpsamplingBilinear2d(scale_factor=self.stem_stride) - + self.hgs = nn.Sequential(*[ Hourglass(hg_conf, self.base_channels) for hg_conf in arch_config.pick('hourglasses') @@ -166,7 +168,7 @@ class StackedHourglass(nn.Module): ]) self.classifier = Conv2dSamePadding(self.base_channels, num_classes, kernel_size=1) - + def forward(self, x: torch.Tensor) -> torch.Tensor: out = self.stem_conv(x) out = self.hgs(out) @@ -176,31 +178,31 @@ class StackedHourglass(nn.Module): class HgnetSegmentationSearchSpace(ConfigSearchSpace): - def __init__(self, + def __init__(self, num_classes: int, img_size: Tuple[int, int], in_channels: int = 3, op_subset: Tuple[str, ...] = ('conv3x3', 'conv5x5', 'conv7x7'), stem_strides: Tuple[int, ...] = (1, 2, 4), - num_blocks: int = 4, + num_blocks: int = 4, downsample_block_max_ops: int = 4, skip_block_max_ops: int = 2, upsample_block_max_ops: int = 4, post_upsample_max_ops: int = 3, **ss_kwargs): - + possible_downsample_factors = [ 2**num_blocks * stem_stride for stem_stride in stem_strides ] w, h = img_size - + assert all(w % d_factor == 0 for d_factor in possible_downsample_factors), \ f'Image width must be divisible by all possible downsample factors ({2**num_blocks} * stem_stride)' - + assert all(h % d_factor == 0 for d_factor in possible_downsample_factors), \ f'Image height must be divisible by all possible downsample factors ({2**num_blocks} * stem_stride)' - + ss_kwargs['builder_kwargs'] = { 'op_subset': op_subset, 'stem_strides': stem_strides, diff --git a/tasks/face_segmentation/search_space/ops.py b/tasks/face_segmentation/search_space/ops.py index daa41b8f..fa5e31a1 100644 --- a/tasks/face_segmentation/search_space/ops.py +++ b/tasks/face_segmentation/search_space/ops.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. from functools import partial from itertools import chain @@ -23,7 +25,7 @@ class Conv2dSamePadding(nn.Conv2d): class ReluConv2d(nn.Module): def __init__(self, in_channels: int, out_channels: int, - kernel_size: int = 3, stride: int = 1, + kernel_size: int = 3, stride: int = 1, bias: bool = False, **kwargs): super().__init__() diff --git a/tasks/face_segmentation/train.py b/tasks/face_segmentation/train.py index 82a52d52..8b13a569 100644 --- a/tasks/face_segmentation/train.py +++ b/tasks/face_segmentation/train.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + from pathlib import Path from argparse import ArgumentParser @@ -17,8 +20,8 @@ parser.add_argument('--dataset_dir', type=Path, help='Face Synthetics dataset di parser.add_argument('--output_dir', type=Path, help='Output directory.', required=True) parser.add_argument('--lr', type=float, default=2e-4) parser.add_argument('--batch_size', type=int, default=16) -parser.add_argument('--epochs', type=int, default=30) -parser.add_argument('--val_check_interval', type=float, default=1) +parser.add_argument('--epochs', type=int, default=1) +parser.add_argument('--val_check_interval', type=float, default=1000) if __name__ == '__main__': @@ -67,6 +70,6 @@ if __name__ == '__main__': rand_range = (0.0, 1.0) export_kwargs = {'opset_version': 11} rand_min, rand_max = rand_range - sample_input = (rand_max - rand_min) * torch.rand(*input_shape) + rand_min).type("torch.FloatTensor") + sample_input = ((rand_max - rand_min) * torch.rand(*input_shape) + rand_min).type("torch.FloatTensor") onnx_file = str(args.output_dir / 'final_model.onnx') torch.onnx.export(model, (sample_input,), onnx_file, input_names=[f"input_0"], **export_kwargs, )