зеркало из https://github.com/microsoft/archai.git
publish new quantizer docker image for snpe-1.64.0.3605 (#97)
* fix: snpe output folder cleanup script, more general. * fix typo * feat: give status.py an optional --name argument. * switch to snpe-1.64.0.3605 * publish new docker image for quantization * find snpe_target_arch dynamically. * fix bug * fix: make it possible for reset.py to reset everything. * fix: remove unnecessary diagnostic output.
This commit is contained in:
Родитель
eea4be4fc6
Коммит
03edcdc0b1
|
@ -0,0 +1,42 @@
|
|||
import argparse
|
||||
import sys
|
||||
import tqdm
|
||||
from status import get_all_status_entities, get_status_table_service, update_status_entity
|
||||
from usage import get_all_usage_entities, get_usage_table_service, update_usage_entity
|
||||
|
||||
CONNECTION_NAME = 'MODEL_STORAGE_CONNECTION_STRING'
|
||||
STATUS_TABLE_NAME = 'STATUS_TABLE_NAME'
|
||||
|
||||
STATUS_TABLE = 'status'
|
||||
CONNECTION_STRING = ''
|
||||
|
||||
# the blobs can be easily copied using azcopy, see https://docs.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-blobs-copy
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Backup the status and usage tables to a new azure storage account ' +
|
||||
f'{CONNECTION_NAME} environment variable.')
|
||||
parser.add_argument('--target', help='Connection string for the target storage account.')
|
||||
args = parser.parse_args()
|
||||
if not args.target:
|
||||
print("Please provide --target connection string")
|
||||
sys.exit(1)
|
||||
|
||||
entities = get_all_status_entities()
|
||||
|
||||
target = get_status_table_service(args.target)
|
||||
|
||||
print(f"Uploading {len(entities)} status entities...")
|
||||
# upload the entities to the new service.
|
||||
with tqdm.tqdm(total=len(entities)) as pbar:
|
||||
for e in entities:
|
||||
update_status_entity(e, target)
|
||||
pbar.update(1)
|
||||
|
||||
usage = get_all_usage_entities()
|
||||
print(f"Uploading {len(usage)} usage entities...")
|
||||
target = get_usage_table_service(args.target)
|
||||
# upload the usage to the new service.
|
||||
with tqdm.tqdm(total=len(usage)) as pbar:
|
||||
for u in usage:
|
||||
update_usage_entity(u, target)
|
||||
pbar.update(1)
|
|
@ -3,7 +3,7 @@
|
|||
import os
|
||||
import json
|
||||
import sys
|
||||
from status import get_all_status_entities, update_status_entity, get_status_table_service
|
||||
from status import get_all_status_entities, update_status_entity, get_status_table_service, get_connection_string
|
||||
from azure.storage.blob import ContainerClient
|
||||
|
||||
CONNECTION_NAME = 'MODEL_STORAGE_CONNECTION_STRING'
|
||||
|
@ -25,7 +25,7 @@ def get_last_modified_date(e, blob_name):
|
|||
return None
|
||||
|
||||
|
||||
service = get_status_table_service()
|
||||
service = get_status_table_service(get_connection_string())
|
||||
|
||||
|
||||
# fix the 'complete' status...
|
||||
|
|
|
@ -25,7 +25,6 @@ def delete_blobs(friendly_name, specific_file=None):
|
|||
if specific_file and file_name != specific_file:
|
||||
continue
|
||||
|
||||
print("Deleting blob: " + file_name)
|
||||
container.delete_blob(blob)
|
||||
|
||||
|
||||
|
|
|
@ -3,7 +3,8 @@
|
|||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from status import get_status, update_status_entity
|
||||
import tqdm
|
||||
from status import get_existing_status, update_status_entity, get_all_status_entities, get_connection_string, get_status_table_service
|
||||
from delete import delete_blobs
|
||||
|
||||
CONNECTION_NAME = 'MODEL_STORAGE_CONNECTION_STRING'
|
||||
|
@ -14,22 +15,25 @@ def reset_metrics(entity, f1, ifs, macs):
|
|||
if f1:
|
||||
for key in ['f1_1k', 'f1_10k', 'f1_1k_f', 'f1_onnx']:
|
||||
if key in entity:
|
||||
print(f"Resetting '{key}'")
|
||||
del entity[key]
|
||||
if ifs:
|
||||
if "mean" in entity:
|
||||
del entity["mean"]
|
||||
if "stdev" in entity:
|
||||
del entity["stdev"]
|
||||
if "total_inference_avg" in entity:
|
||||
del entity["total_inference_avg"]
|
||||
if macs and "macs" in entity:
|
||||
del entity["macs"]
|
||||
if macs and "params" in entity:
|
||||
del entity["params"]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Reset the status of the given model to "new" so it is re-tested, " +\
|
||||
f"using your {CONNECTION_NAME} environment variable.')
|
||||
parser.add_argument('name', help='Friendly name of model status to reset')
|
||||
parser.add_argument('name', help='Friendly name of model status to reset (or * to reset everything!)')
|
||||
parser.add_argument('--all', help='Reset all properties', action="store_true")
|
||||
parser.add_argument('--f1', help='Reset f1 score', action="store_true")
|
||||
parser.add_argument('--ifs', help='Reset total_inference_avg', action="store_true")
|
||||
|
@ -55,11 +59,32 @@ if __name__ == '__main__':
|
|||
print(f"Please specify your {CONNECTION_NAME} environment variable.")
|
||||
sys.exit(1)
|
||||
|
||||
entity = get_status(friendly_name)
|
||||
reset_metrics(entity, f1, ifs, macs)
|
||||
entity['status'] = 'reset'
|
||||
update_status_entity(entity)
|
||||
entities = []
|
||||
|
||||
if quant:
|
||||
delete_blobs(friendly_name, 'model.dlc')
|
||||
delete_blobs(friendly_name, 'model.quant.dlc')
|
||||
service = get_status_table_service(get_connection_string())
|
||||
|
||||
if friendly_name == '*':
|
||||
a = input("Are you sure you want to reset everything (y o n)? ").strip().lower()
|
||||
if a != 'y' and a != 'yes':
|
||||
sys.exit(1)
|
||||
entities = get_all_status_entities()
|
||||
|
||||
else:
|
||||
entity = get_existing_status(friendly_name)
|
||||
if not entity:
|
||||
print(f"Entity {friendly_name} not found")
|
||||
sys.exit(1)
|
||||
entities += [entity]
|
||||
|
||||
with tqdm.tqdm(total=len(entities)) as pbar:
|
||||
for e in entities:
|
||||
name = e['name']
|
||||
reset_metrics(e, f1, ifs, macs)
|
||||
|
||||
if quant:
|
||||
delete_blobs(name, 'model.dlc')
|
||||
delete_blobs(name, 'model.quant.dlc')
|
||||
|
||||
e['status'] = 'reset'
|
||||
update_status_entity(e, service)
|
||||
pbar.update(1)
|
||||
|
|
|
@ -464,7 +464,7 @@ def run_model(name, snpe_root, dataset, conn_string, use_device, benchmark_only)
|
|||
merge_status_entity(entity)
|
||||
test_input = os.path.join('data', 'test')
|
||||
start = get_utc_date()
|
||||
run_batches(filename, test_input, snpe_output_dir)
|
||||
run_batches(filename, snpe_root, test_input, snpe_output_dir)
|
||||
end = get_utc_date()
|
||||
add_usage(get_device(), start, end)
|
||||
|
||||
|
@ -548,12 +548,12 @@ def find_work_prioritized(use_device, benchmark_only, subset_list, no_quantizati
|
|||
elif not is_complete(entity, 'f1_onnx'):
|
||||
priority = 60
|
||||
elif use_device and not is_complete(entity, 'f1_1k'):
|
||||
priority = get_mean_benchmark(entity)
|
||||
priority = 100 + get_mean_benchmark(entity)
|
||||
elif use_device and not is_complete(entity, 'f1_1k_f'):
|
||||
priority = get_mean_benchmark(entity) * 10
|
||||
priority = 100 + get_mean_benchmark(entity) * 10
|
||||
elif use_device and not is_complete(entity, 'f1_10k'):
|
||||
# prioritize by how fast the model is!
|
||||
priority = get_mean_benchmark(entity) * 100
|
||||
priority = 100 + get_mean_benchmark(entity) * 100
|
||||
else:
|
||||
# this model is done!
|
||||
continue
|
||||
|
|
|
@ -43,8 +43,7 @@ def get_connection_string():
|
|||
return CONNECTION_STRING
|
||||
|
||||
|
||||
def get_status_table_service():
|
||||
conn_str = get_connection_string()
|
||||
def get_status_table_service(conn_str):
|
||||
logger = logging.getLogger('azure.core.pipeline.policies.http_logging_policy')
|
||||
logger.setLevel(logging.ERROR)
|
||||
return TableServiceClient.from_connection_string(conn_str=conn_str, logger=logger, logging_enable=False)
|
||||
|
@ -58,7 +57,7 @@ def get_all_status_entities(status=None, not_equal=False, service=None):
|
|||
"""
|
||||
global STATUS_TABLE
|
||||
if not service:
|
||||
service = get_status_table_service()
|
||||
service = get_status_table_service(get_connection_string())
|
||||
table_client = service.create_table_if_not_exists(STATUS_TABLE)
|
||||
|
||||
entities = []
|
||||
|
@ -83,7 +82,7 @@ def get_all_status_entities(status=None, not_equal=False, service=None):
|
|||
def get_status(name, service=None):
|
||||
global STATUS_TABLE
|
||||
if not service:
|
||||
service = get_status_table_service()
|
||||
service = get_status_table_service(get_connection_string())
|
||||
table_client = service.create_table_if_not_exists(STATUS_TABLE)
|
||||
|
||||
try:
|
||||
|
@ -98,10 +97,22 @@ def get_status(name, service=None):
|
|||
return entity
|
||||
|
||||
|
||||
def get_existing_status(name, service=None):
|
||||
global STATUS_TABLE
|
||||
if not service:
|
||||
service = get_status_table_service(get_connection_string())
|
||||
table_client = service.create_table_if_not_exists(STATUS_TABLE)
|
||||
|
||||
try:
|
||||
return table_client.get_entity(partition_key='main', row_key=name)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def update_status_entity(entity, service=None):
|
||||
global STATUS_TABLE
|
||||
if not service:
|
||||
service = get_status_table_service()
|
||||
service = get_status_table_service(get_connection_string())
|
||||
table_client = service.create_table_if_not_exists(STATUS_TABLE)
|
||||
table_client.upsert_entity(entity=entity, mode=UpdateMode.REPLACE)
|
||||
|
||||
|
@ -109,7 +120,7 @@ def update_status_entity(entity, service=None):
|
|||
def merge_status_entity(entity, service=None):
|
||||
global STATUS_TABLE
|
||||
if not service:
|
||||
service = get_status_table_service()
|
||||
service = get_status_table_service(get_connection_string())
|
||||
table_client = service.create_table_if_not_exists(STATUS_TABLE)
|
||||
table_client.update_entity(entity=entity, mode=UpdateMode.MERGE)
|
||||
|
||||
|
@ -117,7 +128,7 @@ def merge_status_entity(entity, service=None):
|
|||
def update_status(name, status, priority=None, service=None):
|
||||
global STATUS_TABLE
|
||||
if not service:
|
||||
service = get_status_table_service()
|
||||
service = get_status_table_service(get_connection_string())
|
||||
table_client = service.create_table_if_not_exists(STATUS_TABLE)
|
||||
|
||||
try:
|
||||
|
@ -140,7 +151,7 @@ def update_status(name, status, priority=None, service=None):
|
|||
def delete_status(name, service=None):
|
||||
global STATUS_TABLE
|
||||
if not service:
|
||||
service = get_status_table_service()
|
||||
service = get_status_table_service(get_connection_string())
|
||||
table_client = service.create_table_if_not_exists(STATUS_TABLE)
|
||||
|
||||
for e in get_all_status_entities():
|
||||
|
|
|
@ -4,7 +4,7 @@ import argparse
|
|||
import os
|
||||
import time
|
||||
from runner import lock_job, unlock_job, get_unique_node_id, set_unique_node_id
|
||||
from status import update_status, get_status_table_service, get_status
|
||||
from status import update_status, get_status_table_service, get_status, get_connection_string
|
||||
|
||||
|
||||
def get_lock(entity):
|
||||
|
@ -43,7 +43,7 @@ if __name__ == '__main__':
|
|||
set_unique_node_id(get_unique_node_id() + f'_{os.getpid()}')
|
||||
name = args.name
|
||||
delay = args.delay
|
||||
service = get_status_table_service()
|
||||
service = get_status_table_service(get_connection_string())
|
||||
entity = update_status(name, 'testing', service=service)
|
||||
|
||||
count = 100
|
||||
|
|
|
@ -44,8 +44,7 @@ def get_connection_string():
|
|||
return CONNECTION_STRING
|
||||
|
||||
|
||||
def get_usage_table_service():
|
||||
conn_str = get_connection_string()
|
||||
def get_usage_table_service(conn_str):
|
||||
logger = logging.getLogger('azure.core.pipeline.policies.http_logging_policy')
|
||||
logger.setLevel(logging.ERROR)
|
||||
return TableServiceClient.from_connection_string(conn_str=conn_str, logger=logger, logging_enable=False)
|
||||
|
@ -55,7 +54,7 @@ def get_all_usage_entities(name_filter=None, service=None):
|
|||
""" Get all usage entities with optional device name filter """
|
||||
global USAGE_TABLE
|
||||
if not service:
|
||||
service = get_usage_table_service()
|
||||
service = get_usage_table_service(get_connection_string())
|
||||
table_client = service.create_table_if_not_exists(USAGE_TABLE)
|
||||
|
||||
entities = []
|
||||
|
@ -76,7 +75,7 @@ def get_all_usage_entities(name_filter=None, service=None):
|
|||
def update_usage_entity(entity, service=None):
|
||||
global USAGE_TABLE
|
||||
if not service:
|
||||
service = get_usage_table_service()
|
||||
service = get_usage_table_service(get_connection_string())
|
||||
table_client = service.create_table_if_not_exists(USAGE_TABLE)
|
||||
table_client.upsert_entity(entity=entity, mode=UpdateMode.REPLACE)
|
||||
|
||||
|
@ -84,7 +83,7 @@ def update_usage_entity(entity, service=None):
|
|||
def add_usage(name, start, end, service=None):
|
||||
global USAGE_TABLE
|
||||
if not service:
|
||||
service = get_usage_table_service()
|
||||
service = get_usage_table_service(get_connection_string())
|
||||
service.create_table_if_not_exists(USAGE_TABLE)
|
||||
|
||||
entity = {
|
||||
|
|
|
@ -54,7 +54,7 @@ RUN wget -O azcopy_v10.tar.gz https://aka.ms/downloadazcopy-v10-linux && tar -xf
|
|||
|
||||
# this echo is a trick to bypass docker build cache.
|
||||
# simply change the echo string every time you want docker build to pull down new bits.
|
||||
RUN echo '07/14/2022 16:35 AM' >/dev/null && git clone "https://github.com/microsoft/archai.git"
|
||||
RUN echo '07/26/2022 04:36 PM' >/dev/null && git clone "https://github.com/microsoft/archai.git"
|
||||
|
||||
RUN source /home/archai/.profile && \
|
||||
pushd /home/archai/archai/devices && \
|
||||
|
|
|
@ -14,7 +14,7 @@ spec:
|
|||
spec:
|
||||
containers:
|
||||
- name: snpe-quantizer
|
||||
image: snpecontainerregistry001.azurecr.io/quantizer:1.15
|
||||
image: snpecontainerregistry001.azurecr.io/quantizer:1.16
|
||||
resources:
|
||||
limits:
|
||||
cpu: 4
|
||||
|
|
|
@ -18,7 +18,7 @@ AKS.
|
|||
|
||||
The setup script requires the following environment variables be set before hand:
|
||||
|
||||
- **SNPE_SDK** - points to a local zip file containing SNPE SKK version `snpe-1.61.0.zip`
|
||||
- **SNPE_SDK** - points to a local zip file containing SNPE SDK version `snpe-1.64.0_3605.zip`
|
||||
- **ANDROID_NDK** - points to a local zip file containing the Android NDK zip file version `android-ndk-r23b-linux.zip`
|
||||
- **INPUT_TESTSET** - points to a local zip file containing 10,000 image test set from your dataset.
|
||||
|
||||
|
|
|
@ -70,12 +70,12 @@ trained ONNX model into this folder. You should have something like:
|
|||
|
||||
1. **Setup your snpe environment**. For onnx toolset use the following:
|
||||
```
|
||||
pushd ~/snpe/snpe-1.60.0.3313
|
||||
pushd ~/snpe/snpe-1.64.0.3605
|
||||
source bin/envsetup.sh -o ~/anaconda3/envs/snap/lib/python3.6/site-packages/onnx
|
||||
```
|
||||
For tensorflow use:
|
||||
```
|
||||
pushd ~/snpe/snpe-1.60.0.3313
|
||||
pushd ~/snpe/snpe-1.64.0.3605
|
||||
source bin/envsetup.sh -ot ~/anaconda3/envs/snap/lib/python3.6/site-packages/tensorflow
|
||||
```
|
||||
|
||||
|
|
|
@ -28,12 +28,12 @@ DEVICE = None
|
|||
# device parameters
|
||||
# the /data mount on the device has 64GB available.
|
||||
DEVICE_WORKING_DIR = "/data/local/tmp"
|
||||
SNPE_TARGET_ARCH = "aarch64-android-clang8.0"
|
||||
SNPE_TARGET_STL = "libgnustl_shared.so"
|
||||
SNPE_BENCH = None
|
||||
RANDOM_INPUTS = 'random_inputs'
|
||||
RANDOM_INPUT_LIST = 'random_raw_list.txt'
|
||||
|
||||
SNPE_ROOT = None
|
||||
snpe_target_arch = None
|
||||
|
||||
def set_device(device):
|
||||
global DEVICE
|
||||
|
@ -219,21 +219,35 @@ def download_results(input_images, start, output_dir):
|
|||
sys.exit(1)
|
||||
|
||||
|
||||
def setup_libs(snpe_root):
|
||||
def get_target_arch(snpe_root):
|
||||
global SNPE_ROOT
|
||||
SNPE_ROOT = snpe_root
|
||||
if not os.path.isdir(snpe_root):
|
||||
print("SNPE_ROOT folder {} not found".format(snpe_root))
|
||||
sys.exit(1)
|
||||
for name in os.listdir(os.path.join(snpe_root, 'lib')):
|
||||
if name.startswith('aarch64-android'):
|
||||
print(f"Using SNPE_TARGET_ARCH {name}")
|
||||
return name
|
||||
|
||||
print("SNPE_ROOT folder {} missing aarch64-android-*".format(snpe_root))
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def setup_libs(snpe_root):
|
||||
global snpe_target_arch
|
||||
snpe_target_arch = get_target_arch(snpe_root)
|
||||
|
||||
print("Pushing SNPE binaries and libraries to device...")
|
||||
shell = Shell()
|
||||
for dir in [f"{DEVICE_WORKING_DIR}/{SNPE_TARGET_ARCH}/bin",
|
||||
f"{DEVICE_WORKING_DIR}/{SNPE_TARGET_ARCH}/lib",
|
||||
for dir in [f"{DEVICE_WORKING_DIR}/{snpe_target_arch}/bin",
|
||||
f"{DEVICE_WORKING_DIR}/{snpe_target_arch}/lib",
|
||||
f"{DEVICE_WORKING_DIR}/dsp/lib"]:
|
||||
shell.run(os.getcwd(), adb(f"shell \"mkdir -p {dir}\""))
|
||||
|
||||
shell.run(
|
||||
os.path.join(snpe_root, "lib", SNPE_TARGET_ARCH),
|
||||
adb(f"push . {DEVICE_WORKING_DIR}/{SNPE_TARGET_ARCH}/lib"), VERBOSE)
|
||||
os.path.join(snpe_root, "lib", snpe_target_arch),
|
||||
adb(f"push . {DEVICE_WORKING_DIR}/{snpe_target_arch}/lib"), VERBOSE)
|
||||
|
||||
shell.run(
|
||||
os.path.join(snpe_root, "lib", 'dsp'),
|
||||
|
@ -241,12 +255,12 @@ def setup_libs(snpe_root):
|
|||
|
||||
for program in ['snpe-net-run', 'snpe-parallel-run', 'snpe-platform-validator', 'snpe-throughput-net-run']:
|
||||
shell.run(
|
||||
os.path.join(snpe_root, "bin", SNPE_TARGET_ARCH),
|
||||
adb(f"push {program} {DEVICE_WORKING_DIR}/{SNPE_TARGET_ARCH}/bin"), VERBOSE)
|
||||
os.path.join(snpe_root, "bin", snpe_target_arch),
|
||||
adb(f"push {program} {DEVICE_WORKING_DIR}/{snpe_target_arch}/bin"), VERBOSE)
|
||||
|
||||
shell.run(
|
||||
os.path.join(snpe_root, "bin", SNPE_TARGET_ARCH),
|
||||
adb(f"shell \"chmod u+x {DEVICE_WORKING_DIR}/{SNPE_TARGET_ARCH}/bin/{program}\""))
|
||||
os.path.join(snpe_root, "bin", snpe_target_arch),
|
||||
adb(f"shell \"chmod u+x {DEVICE_WORKING_DIR}/{snpe_target_arch}/bin/{program}\""))
|
||||
|
||||
|
||||
def clear_images():
|
||||
|
@ -309,10 +323,15 @@ def setup_model(model):
|
|||
|
||||
|
||||
def get_setup():
|
||||
global snpe_target_arch
|
||||
if not snpe_target_arch:
|
||||
print(f"snpe_target_arch is not set")
|
||||
sys.exit(1)
|
||||
|
||||
lib_path = f"{DEVICE_WORKING_DIR}/dsp/lib;/system/lib/rfsa/adsp;/system/vendor/lib/rfsa/adsp;/dsp"
|
||||
setup = f"export SNPE_TARGET_ARCH={SNPE_TARGET_ARCH} && " + \
|
||||
f"export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:{DEVICE_WORKING_DIR}/{SNPE_TARGET_ARCH}/lib && " + \
|
||||
f"export PATH=$PATH:{DEVICE_WORKING_DIR}/{SNPE_TARGET_ARCH}/bin && " + \
|
||||
setup = f"export SNPE_TARGET_ARCH={snpe_target_arch} && " + \
|
||||
f"export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:{DEVICE_WORKING_DIR}/{snpe_target_arch}/lib && " + \
|
||||
f"export PATH=$PATH:{DEVICE_WORKING_DIR}/{snpe_target_arch}/bin && " + \
|
||||
f"export ADSP_LIBRARY_PATH='{lib_path}' && " + \
|
||||
f"cd {DEVICE_WORKING_DIR}/{TASK}/{MODEL}"
|
||||
|
||||
|
@ -325,6 +344,12 @@ def run_test(model):
|
|||
if not model:
|
||||
print("### --run needs the --model parameter")
|
||||
sys.exit(1)
|
||||
|
||||
global snpe_target_arch
|
||||
if not snpe_target_arch:
|
||||
print(f"snpe_target_arch is not set")
|
||||
sys.exit(1)
|
||||
|
||||
shell = Shell()
|
||||
# make sure any previous run output is cleared.
|
||||
shell.run(os.getcwd(), adb(f'shell \"rm -rf {DEVICE_WORKING_DIR}/{TASK}/{MODEL}/output\"'))
|
||||
|
@ -333,7 +358,7 @@ def run_test(model):
|
|||
setup = get_setup()
|
||||
shell.run(
|
||||
os.getcwd(),
|
||||
adb(f"shell \"export SNPE_TARGET_ARCH={SNPE_TARGET_ARCH} && {setup} &&" +
|
||||
adb(f"shell \"export SNPE_TARGET_ARCH={snpe_target_arch} && {setup} &&" +
|
||||
f"snpe-net-run --container ./{model} --input_list ../data/test/input_list_for_device.txt {use_dsp}\""))
|
||||
|
||||
|
||||
|
@ -423,6 +448,9 @@ def run_benchmark(model, name, shape, snpe_root, iterations, random_input_count)
|
|||
print(f"The --snpe {snpe_root} not found")
|
||||
sys.exit(1)
|
||||
|
||||
global snpe_target_arch
|
||||
snpe_target_arch = get_target_arch(snpe_root)
|
||||
|
||||
cwd = os.getcwd()
|
||||
benchmark_dir = os.path.join(cwd, name, 'benchmark')
|
||||
if os.path.isdir(benchmark_dir):
|
||||
|
@ -515,7 +543,11 @@ def compute_results(shape):
|
|||
return get_metrics(image_size, False, dataset, output_dir)
|
||||
|
||||
|
||||
def run_batches(model, images, output_dir):
|
||||
def run_batches(model, snpe_root, images, output_dir):
|
||||
|
||||
global snpe_target_arch
|
||||
snpe_target_arch = get_target_arch(snpe_root)
|
||||
|
||||
files = [x for x in os.listdir(images) if x.endswith(".bin")]
|
||||
files.sort()
|
||||
|
||||
|
@ -578,6 +610,9 @@ if __name__ == '__main__':
|
|||
snpe = os.getenv("SNPE_ROOT")
|
||||
if not snpe:
|
||||
print("please set your SNPE_ROOT environment variable, see readme.md")
|
||||
sys.exit(1)
|
||||
|
||||
snpe_target_arch = get_target_arch(snpe)
|
||||
|
||||
if snpe:
|
||||
sys.path += [f'{snpe}/benchmarks', f'{snpe}/lib/python']
|
||||
|
@ -600,5 +635,5 @@ if __name__ == '__main__':
|
|||
sys.exit(0)
|
||||
|
||||
if args.images:
|
||||
run_batches(model, args.images, OUTPUT_DIR)
|
||||
run_batches(model, snpe, args.images, OUTPUT_DIR)
|
||||
compute_results(shape)
|
||||
|
|
Загрузка…
Ссылка в новой задаче