CNTK/PretrainedModels/download_model.py

78 lines
4.6 KiB
Python

# Copyright (c) Microsoft. All rights reserved.
# Licensed under the MIT license. See LICENSE.md file in the project root
# for full license information.
# ==============================================================================
from __future__ import print_function
import os
import sys
try:
from urllib.request import urlretrieve
except ImportError:
from urllib import urlretrieve
# Registry of the pretrained models this script can fetch.
# Every entry is a (category, model_name, model_url) triple. To add a model,
# append another triple of the same shape; download_model_by_name() looks an
# entry up by model_name and saves it next to this script as
# "<model_name>.model".
models = (
    # --- Image classification ---
    ('Image Classification', 'AlexNet_ImageNet_CNTK',
     'https://www.cntk.ai/Models/CNTK_Pretrained/AlexNet_ImageNet_CNTK.model'),
    ('Image Classification', 'AlexNet_ImageNet_Caffe',
     'https://www.cntk.ai/Models/Caffe_Converted/AlexNet_ImageNet_Caffe.model'),
    ('Image Classification', 'InceptionV3_ImageNet_CNTK',
     'https://www.cntk.ai/Models/CNTK_Pretrained/InceptionV3_ImageNet_CNTK.model'),
    ('Image Classification', 'BNInception_ImageNet_Caffe',
     'https://www.cntk.ai/Models/Caffe_Converted/BNInception_ImageNet_Caffe.model'),
    ('Image Classification', 'ResNet18_ImageNet_CNTK',
     'https://www.cntk.ai/Models/CNTK_Pretrained/ResNet18_ImageNet_CNTK.model'),
    ('Image Classification', 'ResNet_18',
     'https://cntkbuildstorage.blob.core.windows.net/cntk-pretrained-model/ResNet_18.model'),
    ('Image Classification', 'ResNet34_ImageNet_CNTK',
     'https://www.cntk.ai/Models/CNTK_Pretrained/ResNet34_ImageNet_CNTK.model'),
    ('Image Classification', 'ResNet50_ImageNet_CNTK',
     'https://www.cntk.ai/Models/CNTK_Pretrained/ResNet50_ImageNet_CNTK.model'),
    ('Image Classification', 'ResNet101_ImageNet_CNTK',
     'https://www.cntk.ai/Models/CNTK_Pretrained/ResNet101_ImageNet_CNTK.model'),
    ('Image Classification', 'ResNet152_ImageNet_CNTK',
     'https://www.cntk.ai/Models/CNTK_Pretrained/ResNet152_ImageNet_CNTK.model'),
    ('Image Classification', 'ResNet20_CIFAR10_CNTK',
     'https://www.cntk.ai/Models/CNTK_Pretrained/ResNet20_CIFAR10_CNTK.model'),
    ('Image Classification', 'ResNet110_CIFAR10_CNTK',
     'https://www.cntk.ai/Models/CNTK_Pretrained/ResNet110_CIFAR10_CNTK.model'),
    ('Image Classification', 'ResNet50_ImageNet_Caffe',
     'https://www.cntk.ai/Models/Caffe_Converted/ResNet50_ImageNet_Caffe.model'),
    ('Image Classification', 'ResNet101_ImageNet_Caffe',
     'https://www.cntk.ai/Models/Caffe_Converted/ResNet101_ImageNet_Caffe.model'),
    ('Image Classification', 'ResNet152_ImageNet_Caffe',
     'https://www.cntk.ai/Models/Caffe_Converted/ResNet152_ImageNet_Caffe.model'),
    ('Image Classification', 'VGG16_ImageNet_Caffe',
     'https://www.cntk.ai/Models/Caffe_Converted/VGG16_ImageNet_Caffe.model'),
    ('Image Classification', 'VGG19_ImageNet_Caffe',
     'https://www.cntk.ai/Models/Caffe_Converted/VGG19_ImageNet_Caffe.model'),
    # --- Image object detection ---
    ('Image Object Detection', 'Fast-RCNN_grocery100',
     'https://www.cntk.ai/Models/FRCN_Grocery/Fast-RCNN_grocery100.model'),
    ('Image Object Detection', 'Fast-RCNN_Pascal',
     'https://www.cntk.ai/Models/FRCN_Pascal/Fast-RCNN.model'),
)
def download_model(model_file_name, model_url):
    """Download a model file into this script's directory unless it is already there.

    The data is first written to ``<model_file_name>.download`` in the same
    directory. It is renamed to *model_file_name* only after the download
    completes. A failed or interrupted download therefore never leaves a
    truncated model file that a later run would report as "already available".

    :param model_file_name: file name to save the model as, relative to the
        directory containing this script.
    :param model_url: URL to fetch the model from (passed to ``urlretrieve``).
    :raises: propagates any exception raised by ``urlretrieve``; the partial
        ``.download`` file is removed before re-raising.
    """
    model_dir = os.path.dirname(os.path.abspath(__file__))
    filename = os.path.join(model_dir, model_file_name)
    if os.path.exists(filename):
        print('CNTK model already available at ' + filename)
        return
    print('Downloading model from ' + model_url + ', may take a while...')
    partial_filename = filename + '.download'
    try:
        urlretrieve(model_url, partial_filename)
    except BaseException:
        # Network errors and Ctrl+C alike: drop the incomplete file so the
        # next run starts a fresh download, then let the error propagate.
        if os.path.exists(partial_filename):
            os.remove(partial_filename)
        raise
    # Only a completed download ever appears under the final name.
    # os.rename (not os.replace) keeps Python 2 compatibility; the target is
    # known not to exist at this point.
    os.rename(partial_filename, filename)
    print('Saved model as ' + filename)
def download_model_by_name(model_name):
    """Download the model registered in ``models`` under *model_name*.

    A trailing ``.model`` extension on *model_name* is ignored. The first
    registry entry whose name matches is downloaded and saved as
    ``<model_name>.model``. If no entry matches, an error is printed,
    followed by the list of available models.
    """
    suffix = '.model'
    if model_name.endswith(suffix):
        model_name = model_name[:-len(suffix)]
    matches = [entry for entry in models if entry[1] == model_name]
    if not matches:
        print("ERROR: Unknown model name '%s'." % model_name)
        list_available_models()
        return
    download_model(model_name + suffix, matches[0][2])
def list_available_models():
    """Print a table of the models in ``models``, sorted by category then name.

    The first column holds the model name and the second its category. Each
    column is as wide as its longest entry, or its header if that is longer.
    A dashed rule under the header spans both columns.
    """
    name_header = 'Model name'
    cat_header = 'Category'
    print("\nAvailable models (for more information see Readme.md):")
    # Entries are (category, model_name, model_url): x[0] is the category and
    # x[1] the name. The category width was previously computed from x[1].
    # Seeding each max() with the header length keeps the header aligned even
    # for short names, and avoids max() of an empty sequence.
    max_cat = max([len(cat_header)] + [len(x[0]) for x in models])
    max_name = max([len(name_header)] + [len(x[1]) for x in models])
    print("{:<{width}} {}".format(name_header, cat_header, width=max_name))
    print("{:-<{width}} {:-<{width_cat}}".format('', '', width=max_name, width_cat=max_cat))
    for model in sorted(models):
        print("{:<{width}} {}".format(model[1], model[0], width=max_name))
if __name__ == "__main__":
    # Usage: download_model.py <model_name>   or   download_model.py list
    cli_args = sys.argv[1:]
    if len(cli_args) == 1:
        requested = cli_args[0]
        if requested == 'list':
            list_available_models()
        else:
            download_model_by_name(requested)
    else:
        # Wrong number of arguments: show usage followed by the valid names.
        print("Please provide a model name as the single argument. Usage:")
        print(" python download_model.py <model_name>")
        list_available_models()