зеркало из https://github.com/microsoft/archai.git
add tool to convert checkpoints to onnx.
This commit is contained in:
Родитель
b80150587d
Коммит
190b43a8d8
|
@ -0,0 +1,60 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
from pathlib import Path
|
||||
import torch
|
||||
import os
|
||||
from argparse import ArgumentParser
|
||||
from archai.discrete_search.search_spaces.config import ArchConfig
|
||||
from search_space.hgnet import StackedHourglass
|
||||
|
||||
|
||||
def export(checkpoint, model, onnx_file):
|
||||
state_dict = checkpoint['state_dict']
|
||||
# strip 'model.' prefix off the keys!
|
||||
state_dict = dict({(k[6:], state_dict[k]) for k in state_dict})
|
||||
model.load_state_dict(state_dict)
|
||||
input_shapes = [(1, 3, 256, 256)]
|
||||
rand_range = (0.0, 1.0)
|
||||
export_kwargs = {'opset_version': 11}
|
||||
rand_min, rand_max = rand_range
|
||||
sample_inputs = tuple(
|
||||
[
|
||||
((rand_max - rand_min) * torch.rand(*input_shape) + rand_min).type("torch.FloatTensor")
|
||||
for input_shape in input_shapes
|
||||
]
|
||||
)
|
||||
|
||||
torch.onnx.export(
|
||||
model,
|
||||
sample_inputs,
|
||||
onnx_file,
|
||||
input_names=[f"input_{i}" for i in range(len(sample_inputs))],
|
||||
**export_kwargs,
|
||||
)
|
||||
|
||||
print(f'Exported {onnx_file}')
|
||||
|
||||
|
||||
def main():
|
||||
parser = ArgumentParser(
|
||||
"Converts the final_model.ckpt to final_model.onnx, writing the onnx model to the same folder."
|
||||
)
|
||||
parser.add_argument('arch', type=Path, help="Path to config.json file describing the model architecture")
|
||||
parser.add_argument('--checkpoint', help="Path of the checkpoint to export")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
checkpoint = torch.load(args.checkpoint)
|
||||
|
||||
# get the directory name from args.checkpoint
|
||||
output_path = os.path.dirname(os.path.realpath(args.checkpoint))
|
||||
base_name = os.path.splitext(os.path.basename(args.checkpoint))[0]
|
||||
onnx_file = os.path.join(output_path, f'{base_name}.onnx')
|
||||
|
||||
arch_config = ArchConfig.from_file(args.arch)
|
||||
model = StackedHourglass(arch_config, num_classes=18)
|
||||
export(checkpoint, model, onnx_file)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -1,3 +1,5 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
import os
|
||||
import itertools
|
||||
from pathlib import Path
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
from functools import partial
|
||||
from typing import Tuple, Optional, List
|
||||
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
from functools import partial
|
||||
from itertools import chain
|
||||
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from pathlib import Path
|
||||
from argparse import ArgumentParser
|
||||
|
||||
|
@ -17,8 +20,8 @@ parser.add_argument('--dataset_dir', type=Path, help='Face Synthetics dataset di
|
|||
parser.add_argument('--output_dir', type=Path, help='Output directory.', required=True)
|
||||
parser.add_argument('--lr', type=float, default=2e-4)
|
||||
parser.add_argument('--batch_size', type=int, default=16)
|
||||
parser.add_argument('--epochs', type=int, default=30)
|
||||
parser.add_argument('--val_check_interval', type=float, default=1)
|
||||
parser.add_argument('--epochs', type=int, default=1)
|
||||
parser.add_argument('--val_check_interval', type=float, default=1000)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -67,6 +70,6 @@ if __name__ == '__main__':
|
|||
rand_range = (0.0, 1.0)
|
||||
export_kwargs = {'opset_version': 11}
|
||||
rand_min, rand_max = rand_range
|
||||
sample_input = (rand_max - rand_min) * torch.rand(*input_shape) + rand_min).type("torch.FloatTensor")
|
||||
sample_input = ((rand_max - rand_min) * torch.rand(*input_shape) + rand_min).type("torch.FloatTensor")
|
||||
onnx_file = str(args.output_dir / 'final_model.onnx')
|
||||
torch.onnx.export(model, (sample_input,), onnx_file, input_names=[f"input_0"], **export_kwargs, )
|
||||
|
|
Загрузка…
Ссылка в новой задаче