add tool to convert checkpoints to onnx.

This commit is contained in:
Chris Lovett 2023-03-31 23:50:45 -07:00
Родитель b80150587d
Коммит 190b43a8d8
5 изменённых файлов: 106 добавлений и 37 удалений

Просмотреть файл

@ -0,0 +1,60 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from pathlib import Path
import torch
import os
from argparse import ArgumentParser
from archai.discrete_search.search_spaces.config import ArchConfig
from search_space.hgnet import StackedHourglass
def export(checkpoint, model, onnx_file):
state_dict = checkpoint['state_dict']
# strip 'model.' prefix off the keys!
state_dict = dict({(k[6:], state_dict[k]) for k in state_dict})
model.load_state_dict(state_dict)
input_shapes = [(1, 3, 256, 256)]
rand_range = (0.0, 1.0)
export_kwargs = {'opset_version': 11}
rand_min, rand_max = rand_range
sample_inputs = tuple(
[
((rand_max - rand_min) * torch.rand(*input_shape) + rand_min).type("torch.FloatTensor")
for input_shape in input_shapes
]
)
torch.onnx.export(
model,
sample_inputs,
onnx_file,
input_names=[f"input_{i}" for i in range(len(sample_inputs))],
**export_kwargs,
)
print(f'Exported {onnx_file}')
def main():
parser = ArgumentParser(
"Converts the final_model.ckpt to final_model.onnx, writing the onnx model to the same folder."
)
parser.add_argument('arch', type=Path, help="Path to config.json file describing the model architecture")
parser.add_argument('--checkpoint', help="Path of the checkpoint to export")
args = parser.parse_args()
checkpoint = torch.load(args.checkpoint)
# get the directory name from args.checkpoint
output_path = os.path.dirname(os.path.realpath(args.checkpoint))
base_name = os.path.splitext(os.path.basename(args.checkpoint))[0]
onnx_file = os.path.join(output_path, f'{base_name}.onnx')
arch_config = ArchConfig.from_file(args.arch)
model = StackedHourglass(arch_config, num_classes=18)
export(checkpoint, model, onnx_file)
if __name__ == '__main__':
main()

Просмотреть файл

@ -1,3 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import itertools
from pathlib import Path

Просмотреть файл

@ -1,3 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from functools import partial
from typing import Tuple, Optional, List

Просмотреть файл

@ -1,3 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from functools import partial
from itertools import chain

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from pathlib import Path
from argparse import ArgumentParser
@ -17,8 +20,8 @@ parser.add_argument('--dataset_dir', type=Path, help='Face Synthetics dataset di
parser.add_argument('--output_dir', type=Path, help='Output directory.', required=True)
parser.add_argument('--lr', type=float, default=2e-4)
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--epochs', type=int, default=30)
parser.add_argument('--val_check_interval', type=float, default=1)
parser.add_argument('--epochs', type=int, default=1)
parser.add_argument('--val_check_interval', type=float, default=1000)
if __name__ == '__main__':
@ -67,6 +70,6 @@ if __name__ == '__main__':
rand_range = (0.0, 1.0)
export_kwargs = {'opset_version': 11}
rand_min, rand_max = rand_range
sample_input = (rand_max - rand_min) * torch.rand(*input_shape) + rand_min).type("torch.FloatTensor")
sample_input = ((rand_max - rand_min) * torch.rand(*input_shape) + rand_min).type("torch.FloatTensor")
onnx_file = str(args.output_dir / 'final_model.onnx')
torch.onnx.export(model, (sample_input,), onnx_file, input_names=[f"input_0"], **export_kwargs, )