зеркало из https://github.com/microsoft/archai.git
add tool to convert checkpoints to onnx.
This commit is contained in:
Родитель
b80150587d
Коммит
190b43a8d8
|
@ -0,0 +1,60 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
from pathlib import Path
|
||||
import torch
|
||||
import os
|
||||
from argparse import ArgumentParser
|
||||
from archai.discrete_search.search_spaces.config import ArchConfig
|
||||
from search_space.hgnet import StackedHourglass
|
||||
|
||||
|
||||
def export(checkpoint, model, onnx_file):
|
||||
state_dict = checkpoint['state_dict']
|
||||
# strip 'model.' prefix off the keys!
|
||||
state_dict = dict({(k[6:], state_dict[k]) for k in state_dict})
|
||||
model.load_state_dict(state_dict)
|
||||
input_shapes = [(1, 3, 256, 256)]
|
||||
rand_range = (0.0, 1.0)
|
||||
export_kwargs = {'opset_version': 11}
|
||||
rand_min, rand_max = rand_range
|
||||
sample_inputs = tuple(
|
||||
[
|
||||
((rand_max - rand_min) * torch.rand(*input_shape) + rand_min).type("torch.FloatTensor")
|
||||
for input_shape in input_shapes
|
||||
]
|
||||
)
|
||||
|
||||
torch.onnx.export(
|
||||
model,
|
||||
sample_inputs,
|
||||
onnx_file,
|
||||
input_names=[f"input_{i}" for i in range(len(sample_inputs))],
|
||||
**export_kwargs,
|
||||
)
|
||||
|
||||
print(f'Exported {onnx_file}')
|
||||
|
||||
|
||||
def main():
|
||||
parser = ArgumentParser(
|
||||
"Converts the final_model.ckpt to final_model.onnx, writing the onnx model to the same folder."
|
||||
)
|
||||
parser.add_argument('arch', type=Path, help="Path to config.json file describing the model architecture")
|
||||
parser.add_argument('--checkpoint', help="Path of the checkpoint to export")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
checkpoint = torch.load(args.checkpoint)
|
||||
|
||||
# get the directory name from args.checkpoint
|
||||
output_path = os.path.dirname(os.path.realpath(args.checkpoint))
|
||||
base_name = os.path.splitext(os.path.basename(args.checkpoint))[0]
|
||||
onnx_file = os.path.join(output_path, f'{base_name}.onnx')
|
||||
|
||||
arch_config = ArchConfig.from_file(args.arch)
|
||||
model = StackedHourglass(arch_config, num_classes=18)
|
||||
export(checkpoint, model, onnx_file)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -1,3 +1,5 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
import os
|
||||
import itertools
|
||||
from pathlib import Path
|
||||
|
@ -93,7 +95,7 @@ if __name__ == '__main__':
|
|||
]
|
||||
|
||||
evaluator = RemoteAzureBenchmarkEvaluator(
|
||||
input_shape=input_shape,
|
||||
input_shape=input_shape,
|
||||
onnx_export_kwargs={'opset_version': 11},
|
||||
**target_config
|
||||
)
|
||||
|
@ -121,7 +123,7 @@ if __name__ == '__main__':
|
|||
|
||||
if not args.serial_training:
|
||||
partial_tr_obj = RayParallelEvaluator(
|
||||
partial_tr_obj, num_gpus=args.gpus_per_job,
|
||||
partial_tr_obj, num_gpus=args.gpus_per_job,
|
||||
max_calls=1
|
||||
)
|
||||
|
||||
|
@ -138,7 +140,7 @@ if __name__ == '__main__':
|
|||
# Search algorithm
|
||||
algo_config = search_config['algorithm']
|
||||
algo = AVAILABLE_ALGOS[algo_config['name']](
|
||||
search_space, so, dataset_provider,
|
||||
search_space, so, dataset_provider,
|
||||
output_dir=args.output_dir, seed=args.seed,
|
||||
**algo_config.get('params', {}),
|
||||
)
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
from functools import partial
|
||||
from typing import Tuple, Optional, List
|
||||
|
||||
|
@ -21,31 +23,31 @@ def hgnet_param_tree_factory(stem_strides: Tuple[int, ...] = (2, 4),
|
|||
skip_block_max_ops: int = 3,
|
||||
upsample_block_max_ops: int = 4):
|
||||
assert num_blocks > 1, 'num_blocks must be greater than 1'
|
||||
|
||||
|
||||
return ArchParamTree({
|
||||
'stem_stride': DiscreteChoice(stem_strides),
|
||||
'stem_stride': DiscreteChoice(stem_strides),
|
||||
'base_ch': DiscreteChoice(base_channels),
|
||||
|
||||
|
||||
'hourglasses': repeat_config({
|
||||
|
||||
'downsample_blocks': repeat_config({
|
||||
'layers': repeat_config({
|
||||
'op': DiscreteChoice(op_subset)
|
||||
}, repeat_times=range(1, downsample_block_max_ops + 1), share_arch=False),
|
||||
|
||||
|
||||
'ch_expansion_factor': DiscreteChoice([1.0, 1.2, 1.5, 1.6, 2.0, 2.2]),
|
||||
}, repeat_times=num_blocks),
|
||||
|
||||
|
||||
'skip_blocks': repeat_config({
|
||||
'layers': repeat_config({
|
||||
'op': DiscreteChoice(op_subset)
|
||||
}, repeat_times=range(0, skip_block_max_ops+1), share_arch=False),
|
||||
}, repeat_times=range(0, skip_block_max_ops+1), share_arch=False),
|
||||
}, repeat_times=num_blocks-1),
|
||||
|
||||
|
||||
'upsample_blocks': repeat_config({
|
||||
'layers': repeat_config({
|
||||
'op': DiscreteChoice(op_subset)
|
||||
}, repeat_times=range(1, upsample_block_max_ops+1), share_arch=False),
|
||||
}, repeat_times=range(1, upsample_block_max_ops+1), share_arch=False),
|
||||
}, repeat_times=num_blocks-1),
|
||||
}, repeat_times=range(1, max_num_hourglass+1), share_arch=share_hourglass_arch),
|
||||
|
||||
|
@ -62,20 +64,20 @@ class Hourglass(nn.Module):
|
|||
|
||||
self.upsample = nn.UpsamplingBilinear2d(scale_factor=2)
|
||||
self.chs = [self.base_channels]
|
||||
|
||||
|
||||
# Calculates channels on each branch
|
||||
for block_cfg in arch_config.pick('downsample_blocks'):
|
||||
for block_cfg in arch_config.pick('downsample_blocks'):
|
||||
self.chs.append(
|
||||
int(self.chs[-1] * block_cfg.pick('ch_expansion_factor'))
|
||||
)
|
||||
|
||||
|
||||
self.nb_blocks = len(self.chs) - 1
|
||||
|
||||
# Downsample blocks
|
||||
self.down_blocks = nn.ModuleList()
|
||||
for block_idx, block_cfg in enumerate(arch_config.pick('downsample_blocks')):
|
||||
in_ch, out_ch = self.chs[block_idx], self.chs[block_idx + 1]
|
||||
|
||||
|
||||
down_block = [
|
||||
OPS[layer_cfg.pick('op')](
|
||||
(in_ch if layer_idx == 0 else out_ch),
|
||||
|
@ -84,21 +86,21 @@ class Hourglass(nn.Module):
|
|||
)
|
||||
for layer_idx, layer_cfg in enumerate(block_cfg.pick('layers'))
|
||||
]
|
||||
|
||||
|
||||
self.down_blocks.append(nn.Sequential(*down_block))
|
||||
|
||||
|
||||
# Skip blocks
|
||||
self.skip_blocks = nn.ModuleList()
|
||||
for block_idx, block_cfg in enumerate(arch_config.pick('skip_blocks')):
|
||||
out_ch = self.chs[block_idx + 1]
|
||||
|
||||
|
||||
skip_block = [
|
||||
OPS.get(layer_cfg.pick('op'))(out_ch, out_ch)
|
||||
for layer_idx, layer_cfg in enumerate(block_cfg.pick('layers'))
|
||||
]
|
||||
|
||||
|
||||
self.skip_blocks.append(nn.Sequential(*skip_block))
|
||||
|
||||
|
||||
# Upsample blocks
|
||||
self.up_blocks = nn.ModuleList()
|
||||
for block_idx, block_cfg in enumerate(arch_config.pick('upsample_blocks')):
|
||||
|
@ -124,7 +126,7 @@ class Hourglass(nn.Module):
|
|||
out = self.down_blocks[i](inp)
|
||||
skip_connections[i] = self.skip_blocks[i](out)
|
||||
inp = out
|
||||
|
||||
|
||||
# Last downsample branch
|
||||
out = self.down_blocks[-1](inp)
|
||||
|
||||
|
@ -137,24 +139,24 @@ class Hourglass(nn.Module):
|
|||
class StackedHourglass(nn.Module):
|
||||
def __init__(self, arch_config: ArchConfig, num_classes: int, in_channels: int = 3):
|
||||
super().__init__()
|
||||
|
||||
|
||||
self.num_classes = num_classes
|
||||
self.in_channels = in_channels
|
||||
self.arch_config = arch_config
|
||||
self.base_channels = arch_config.pick('base_ch')
|
||||
|
||||
|
||||
# Classifier
|
||||
self.classifier = nn.Conv2d(self.base_channels, num_classes, kernel_size=1)
|
||||
|
||||
# Stem convolution
|
||||
self.stem_stride = arch_config.pick('stem_stride')
|
||||
self.stem_conv = ReluConv2d(
|
||||
in_channels=in_channels, out_channels=self.base_channels,
|
||||
in_channels=in_channels, out_channels=self.base_channels,
|
||||
stride=self.stem_stride
|
||||
)
|
||||
|
||||
|
||||
self.final_upsample = nn.UpsamplingBilinear2d(scale_factor=self.stem_stride)
|
||||
|
||||
|
||||
self.hgs = nn.Sequential(*[
|
||||
Hourglass(hg_conf, self.base_channels)
|
||||
for hg_conf in arch_config.pick('hourglasses')
|
||||
|
@ -166,7 +168,7 @@ class StackedHourglass(nn.Module):
|
|||
])
|
||||
|
||||
self.classifier = Conv2dSamePadding(self.base_channels, num_classes, kernel_size=1)
|
||||
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
out = self.stem_conv(x)
|
||||
out = self.hgs(out)
|
||||
|
@ -176,31 +178,31 @@ class StackedHourglass(nn.Module):
|
|||
|
||||
|
||||
class HgnetSegmentationSearchSpace(ConfigSearchSpace):
|
||||
def __init__(self,
|
||||
def __init__(self,
|
||||
num_classes: int,
|
||||
img_size: Tuple[int, int],
|
||||
in_channels: int = 3,
|
||||
op_subset: Tuple[str, ...] = ('conv3x3', 'conv5x5', 'conv7x7'),
|
||||
stem_strides: Tuple[int, ...] = (1, 2, 4),
|
||||
num_blocks: int = 4,
|
||||
num_blocks: int = 4,
|
||||
downsample_block_max_ops: int = 4,
|
||||
skip_block_max_ops: int = 2,
|
||||
upsample_block_max_ops: int = 4,
|
||||
post_upsample_max_ops: int = 3,
|
||||
**ss_kwargs):
|
||||
|
||||
|
||||
possible_downsample_factors = [
|
||||
2**num_blocks * stem_stride for stem_stride in stem_strides
|
||||
]
|
||||
|
||||
w, h = img_size
|
||||
|
||||
|
||||
assert all(w % d_factor == 0 for d_factor in possible_downsample_factors), \
|
||||
f'Image width must be divisible by all possible downsample factors ({2**num_blocks} * stem_stride)'
|
||||
|
||||
|
||||
assert all(h % d_factor == 0 for d_factor in possible_downsample_factors), \
|
||||
f'Image height must be divisible by all possible downsample factors ({2**num_blocks} * stem_stride)'
|
||||
|
||||
|
||||
ss_kwargs['builder_kwargs'] = {
|
||||
'op_subset': op_subset,
|
||||
'stem_strides': stem_strides,
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
from functools import partial
|
||||
from itertools import chain
|
||||
|
||||
|
@ -23,7 +25,7 @@ class Conv2dSamePadding(nn.Conv2d):
|
|||
|
||||
class ReluConv2d(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int,
|
||||
kernel_size: int = 3, stride: int = 1,
|
||||
kernel_size: int = 3, stride: int = 1,
|
||||
bias: bool = False, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from pathlib import Path
|
||||
from argparse import ArgumentParser
|
||||
|
||||
|
@ -17,8 +20,8 @@ parser.add_argument('--dataset_dir', type=Path, help='Face Synthetics dataset di
|
|||
parser.add_argument('--output_dir', type=Path, help='Output directory.', required=True)
|
||||
parser.add_argument('--lr', type=float, default=2e-4)
|
||||
parser.add_argument('--batch_size', type=int, default=16)
|
||||
parser.add_argument('--epochs', type=int, default=30)
|
||||
parser.add_argument('--val_check_interval', type=float, default=1)
|
||||
parser.add_argument('--epochs', type=int, default=1)
|
||||
parser.add_argument('--val_check_interval', type=float, default=1000)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -67,6 +70,6 @@ if __name__ == '__main__':
|
|||
rand_range = (0.0, 1.0)
|
||||
export_kwargs = {'opset_version': 11}
|
||||
rand_min, rand_max = rand_range
|
||||
sample_input = (rand_max - rand_min) * torch.rand(*input_shape) + rand_min).type("torch.FloatTensor")
|
||||
sample_input = ((rand_max - rand_min) * torch.rand(*input_shape) + rand_min).type("torch.FloatTensor")
|
||||
onnx_file = str(args.output_dir / 'final_model.onnx')
|
||||
torch.onnx.export(model, (sample_input,), onnx_file, input_names=[f"input_0"], **export_kwargs, )
|
||||
|
|
Загрузка…
Ссылка в новой задаче