refine package import
This commit is contained in:
Родитель
ac83a2bffa
Коммит
72eee00c08
|
@ -7,20 +7,26 @@ try:
|
|||
except ModuleNotFoundError:
|
||||
__version__ = 'UNKNOWN'
|
||||
|
||||
from .nn_meter import (
|
||||
nnMeter,
|
||||
import logging
|
||||
from functools import partial, partialmethod
|
||||
|
||||
from .predictor import (
|
||||
nnMeterPredictor,
|
||||
load_latency_predictor,
|
||||
list_latency_predictors,
|
||||
latency_metrics
|
||||
)
|
||||
from .ir_converter import (
|
||||
model_file_to_graph,
|
||||
model_to_graph,
|
||||
model_to_graph
|
||||
)
|
||||
from .utils import (
|
||||
create_user_configs,
|
||||
change_user_data_folder
|
||||
)
|
||||
from .utils.utils import download_from_url
|
||||
from .predictor import latency_metrics
|
||||
from .dataset import bench_dataset # TODO: add GNNDataloader and GNNDataset here @wenxuan
|
||||
import logging
|
||||
from functools import partial, partialmethod
|
||||
from .dataset import bench_dataset
|
||||
from .utils import download_from_url
|
||||
|
||||
|
||||
logging.KEYINFO = 22
|
||||
logging.addLevelName(logging.KEYINFO, 'KEYINFO')
|
||||
|
|
|
@ -1,13 +1,12 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
import os, sys
|
||||
from nn_meter.predictor import latency_metrics
|
||||
import logging
|
||||
import jsonlines
|
||||
from glob import glob
|
||||
|
||||
from nn_meter.nn_meter import list_latency_predictors, load_latency_predictor, get_user_data_folder
|
||||
from nn_meter import download_from_url
|
||||
import jsonlines
|
||||
import logging
|
||||
from nn_meter.predictor import latency_metrics, list_latency_predictors, load_latency_predictor
|
||||
from nn_meter.utils import download_from_url, get_user_data_folder
|
||||
|
||||
|
||||
__user_dataset_folder__ = os.path.join(get_user_data_folder(), 'dataset')
|
||||
|
|
|
@ -1,16 +1,18 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
import torch
|
||||
import jsonlines
|
||||
import os
|
||||
import random
|
||||
import torch
|
||||
import jsonlines
|
||||
from .bench_dataset import bench_dataset
|
||||
from nn_meter.nn_meter import get_user_data_folder
|
||||
from nn_meter.utils.utils import try_import_dgl
|
||||
from nn_meter.utils import get_user_data_folder
|
||||
from nn_meter.utils.import_package import try_import_dgl
|
||||
|
||||
|
||||
RAW_DATA_URL = "https://github.com/microsoft/nn-Meter/releases/download/v1.0-data/datasets.zip"
|
||||
__user_dataset_folder__ = os.path.join(get_user_data_folder(), 'dataset')
|
||||
|
||||
|
||||
hws = [
|
||||
"cortexA76cpu_tflite21",
|
||||
"adreno640gpu_tflite21",
|
||||
|
|
|
@ -2,10 +2,10 @@
|
|||
# Licensed under the MIT license.
|
||||
import numpy as np
|
||||
|
||||
from nn_meter.utils.graph_tool import ModelGraph
|
||||
from .frozenpb_parser import FrozenPbParser
|
||||
from .shape_inference import ShapeInference
|
||||
from .shape_fetcher import ShapeFetcher
|
||||
from nn_meter.utils.graph_tool import ModelGraph
|
||||
|
||||
class FrozenPbConverter:
|
||||
def __init__(self, file_name):
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
from nn_meter.utils.utils import try_import_tensorflow
|
||||
from .protobuf_helper import ProtobufHelper
|
||||
from .shape_fetcher import ShapeFetcher
|
||||
import copy
|
||||
import re
|
||||
import copy
|
||||
import logging
|
||||
|
||||
from .protobuf_helper import ProtobufHelper
|
||||
from nn_meter.utils.import_package import try_import_tensorflow
|
||||
|
||||
|
||||
logging = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
|
@ -1,9 +1,8 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
from nn_meter.utils.utils import try_import_tensorflow
|
||||
import numpy as np
|
||||
from typing import List
|
||||
|
||||
from nn_meter.utils.utils import try_import_tensorflow
|
||||
|
||||
class ShapeFetcher:
|
||||
def __init__(self, input_graph):
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
from .protobuf_helper import ProtobufHelper as ph
|
||||
from functools import reduce
|
||||
import copy
|
||||
import math
|
||||
import logging
|
||||
from .protobuf_helper import ProtobufHelper as ph
|
||||
|
||||
|
||||
logging = logging.getLogger(__name__)
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
from nn_meter.utils.utils import try_import_onnx
|
||||
import logging
|
||||
import networkx as nx
|
||||
from itertools import chain
|
||||
from .utils import get_tensor_shape
|
||||
from .constants import SLICE_TYPE
|
||||
from itertools import chain
|
||||
import logging
|
||||
from nn_meter.utils.import_package import try_import_onnx
|
||||
|
||||
|
||||
class OnnxConverter:
|
||||
|
|
|
@ -1,7 +1,5 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
|
||||
def get_tensor_shape(tensor):
|
||||
shape = []
|
||||
for dim in tensor.type.tensor_type.shape.dim:
|
||||
|
|
|
@ -1,11 +1,9 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
from nn_meter.utils.utils import try_import_onnx, try_import_torch, try_import_onnxsim, try_import_nni
|
||||
import tempfile
|
||||
from nn_meter.ir_converter.onnx_converter import OnnxConverter
|
||||
|
||||
|
||||
from ..onnx_converter import OnnxConverter
|
||||
from .opset_map import nni_attr_map, nni_type_map
|
||||
from nn_meter.utils.import_package import try_import_onnx, try_import_torch, try_import_onnxsim, try_import_nni
|
||||
|
||||
|
||||
def _nchw_to_nhwc(shapes):
|
||||
|
|
|
@ -2,11 +2,12 @@
|
|||
# Licensed under the MIT license.
|
||||
import json
|
||||
import logging
|
||||
from nn_meter.utils.utils import try_import_onnx, try_import_torch, try_import_torchvision_models
|
||||
from nn_meter.utils.import_package import try_import_onnx, try_import_torch, try_import_torchvision_models
|
||||
from .onnx_converter import OnnxConverter
|
||||
from .frozenpb_converter import FrozenPbConverter
|
||||
from .torch_converter import NNIBasedTorchConverter, OnnxBasedTorchConverter, NNIIRConverter
|
||||
|
||||
|
||||
def model_file_to_graph(filename: str, model_type: str, input_shape=(1, 3, 224, 224), apply_nni=False):
|
||||
"""
|
||||
read the given file and convert the model in the file content to nn-Meter IR graph object
|
||||
|
@ -106,10 +107,12 @@ def onnx_model_to_graph(model):
|
|||
converter = OnnxConverter(model)
|
||||
return converter.convert()
|
||||
|
||||
|
||||
def nni_model_to_graph(model):
|
||||
converter = NNIIRConverter(model)
|
||||
return converter.convert()
|
||||
|
||||
|
||||
def torch_model_to_graph(model, input_shape=(1, 3, 224, 224), apply_nni=False):
|
||||
torch = try_import_torch()
|
||||
args = torch.randn(*input_shape)
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
from .detection.detector import KernelDetector
|
||||
from .kernel_detector import KernelDetector
|
||||
|
|
|
@ -1,2 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
|
@ -3,7 +3,7 @@
|
|||
import os
|
||||
import json
|
||||
from nn_meter.utils.graph_tool import ModelGraph
|
||||
from nn_meter.kernel_detector.utils.ir_tools import convert_nodes
|
||||
from ..utils.ir_tools import convert_nodes
|
||||
|
||||
|
||||
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
|
|
|
@ -1,11 +1,10 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
from nn_meter.kernel_detector.rulelib.rule_reader import RuleReader
|
||||
from nn_meter.kernel_detector.rulelib.rule_splitter import RuleSplitter
|
||||
from nn_meter.utils.graph_tool import ModelGraph
|
||||
from nn_meter.kernel_detector.utils.constants import DUMMY_TYPES
|
||||
from nn_meter.kernel_detector.utils.ir_tools import convert_nodes
|
||||
# import logging
|
||||
from .utils.constants import DUMMY_TYPES
|
||||
from .utils.ir_tools import convert_nodes
|
||||
from .rulelib.rule_reader import RuleReader
|
||||
from .rulelib.rule_splitter import RuleSplitter
|
||||
|
||||
|
||||
class KernelDetector:
|
|
@ -1,8 +1,8 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
import json
|
||||
from ..fusionlib import get_fusion_unit
|
||||
from nn_meter.utils.graph_tool import ModelGraph
|
||||
from nn_meter.kernel_detector.fusionlib import get_fusion_unit
|
||||
|
||||
|
||||
class RuleReader:
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
from .rule_reader import RuleReader
|
||||
from ..utils.match_helper import MatchHelper
|
||||
from ..utils.fusion_aware_graph import FusionAwareGraph
|
||||
from nn_meter.utils.graph_tool import ModelGraph
|
||||
from nn_meter.kernel_detector.utils.match_helper import MatchHelper
|
||||
from nn_meter.kernel_detector.utils.fusion_aware_graph import FusionAwareGraph
|
||||
|
||||
|
||||
class RuleSplitter:
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
from nn_meter.utils.graph_tool import ModelGraph
|
||||
from .union_find import UF
|
||||
import networkx as nx
|
||||
from .union_find import UF
|
||||
from nn_meter.utils.graph_tool import ModelGraph
|
||||
|
||||
|
||||
class FusionAwareGraph:
|
||||
|
|
|
@ -5,12 +5,7 @@ import os
|
|||
import sys
|
||||
import argparse
|
||||
import logging
|
||||
from nn_meter.nn_meter import *
|
||||
|
||||
__user_config_folder__ = os.path.expanduser('~/.nn_meter/config')
|
||||
__user_data_folder__ = os.path.expanduser('~/.nn_meter/data')
|
||||
|
||||
__predictors_cfg_filename__ = 'predictors.yaml'
|
||||
from nn_meter import *
|
||||
|
||||
|
||||
def list_latency_predictors_cli():
|
||||
|
@ -62,6 +57,7 @@ def apply_latency_predictor_cli(args):
|
|||
|
||||
return result
|
||||
|
||||
|
||||
def get_nnmeter_ir_cli(args):
|
||||
"""convert pb file or onnx file to nn-Meter IR graph according to the command line interface arguments
|
||||
"""
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
from .predictors.utils import latency_metrics
|
||||
|
||||
|
||||
from .prediction.utils import latency_metrics
|
||||
from .nn_meter_predictor import nnMeterPredictor, list_latency_predictors, load_latency_predictor
|
||||
|
|
|
@ -1,18 +1,18 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
from glob import glob
|
||||
from nn_meter.predictor.predictors.predict_by_kernel import nn_predict
|
||||
from nn_meter.kernel_detector import KernelDetector
|
||||
from nn_meter.ir_converter import model_file_to_graph, model_to_graph
|
||||
from .utils import loading_to_local
|
||||
|
||||
import yaml
|
||||
import os
|
||||
import yaml
|
||||
import pkg_resources
|
||||
from shutil import copyfile
|
||||
from packaging import version
|
||||
import logging
|
||||
|
||||
from .utils import loading_to_local
|
||||
from .prediction.predict_by_kernel import nn_predict
|
||||
from nn_meter.kernel_detector import KernelDetector
|
||||
from nn_meter.utils import load_config_file, get_user_data_folder
|
||||
from nn_meter.ir_converter import model_file_to_graph, model_to_graph
|
||||
|
||||
|
||||
__predictors_cfg_filename__ = 'predictors.yaml'
|
||||
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
import logging
|
||||
import numpy as np
|
||||
from sklearn.metrics import mean_squared_error
|
||||
import logging
|
||||
|
||||
|
||||
def get_flop(input_channel, output_channel, k, H, W, stride):
|
|
@ -1,9 +1,9 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import numpy as np
|
||||
from sklearn.metrics import mean_squared_error
|
||||
|
||||
|
||||
def get_kernel_name(optype):
|
||||
"""
|
||||
for many similar kernels, we use one kernel predictor since their latency difference is negligible,
|
||||
|
@ -34,6 +34,7 @@ def get_kernel_name(optype):
|
|||
|
||||
return optype
|
||||
|
||||
|
||||
def get_accuracy(y_pred, y_true, threshold=0.01):
|
||||
a = (y_true - y_pred) / y_true
|
||||
b = np.where(abs(a) <= threshold)
|
|
@ -7,7 +7,7 @@ from zipfile import ZipFile
|
|||
from tqdm import tqdm
|
||||
import requests
|
||||
import logging
|
||||
from nn_meter.utils.utils import download_from_url
|
||||
from nn_meter.utils import download_from_url
|
||||
|
||||
|
||||
def loading_to_local(pred_info, dir="data/predictorzoo"):
|
||||
|
|
|
@ -6,3 +6,4 @@ from config_manager import (
|
|||
change_user_data_folder,
|
||||
load_config_file
|
||||
)
|
||||
from utils import download_from_url
|
|
@ -1,14 +1,9 @@
|
|||
from glob import glob
|
||||
from nn_meter.predictor.predictors.predict_by_kernel import nn_predict
|
||||
from nn_meter.kernel_detector import KernelDetector
|
||||
from nn_meter.ir_converter import model_file_to_graph, model_to_graph
|
||||
from nn_meter.predictor.load_predictors import loading_to_local
|
||||
|
||||
import yaml
|
||||
import os
|
||||
import logging
|
||||
import pkg_resources
|
||||
from shutil import copyfile
|
||||
import logging
|
||||
|
||||
|
||||
__user_config_folder__ = os.path.expanduser('~/.nn_meter/config')
|
||||
__default_user_data_folder__ = os.path.expanduser('~/.nn_meter/data')
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
from packaging import version
|
||||
import logging
|
||||
from packaging import version
|
||||
|
||||
|
||||
def try_import_onnx(require_version = ["1.9.0"]):
|
||||
|
|
|
@ -20,7 +20,7 @@ def download_from_url(urladdr, ppath):
|
|||
if not os.path.isdir(ppath):
|
||||
os.makedirs(ppath)
|
||||
|
||||
# logging.keyinfo(f'Download from {urladdr}')
|
||||
logging.keyinfo(f'Download from {urladdr}')
|
||||
response = requests.get(urladdr, stream=True)
|
||||
total_size_in_bytes = int(response.headers.get("content-length", 0))
|
||||
block_size = 2048 # 2 Kibibyte
|
||||
|
|
Загрузка…
Ссылка в новой задаче