decouple predictor config and overall config

This commit is contained in:
jiahangxu 2021-11-16 14:47:25 +08:00
Родитель 1323fe0887
Коммит fbe1545d1d
4 изменённых файлов: 24 добавлений и 25 удалений

Просмотреть файл

@ -4,10 +4,10 @@ import os
from packaging import version
import logging
from .utils import loading_to_local
from .utils import load_config_file, loading_to_local
from .prediction.predict_by_kernel import nn_predict
from nn_meter.kernel_detector import KernelDetector
from nn_meter.utils import load_config_file, get_user_data_folder
from nn_meter.utils import get_user_data_folder
from nn_meter.ir_converter import model_file_to_graph, model_to_graph

Просмотреть файл

@ -3,11 +3,11 @@
import pickle
import os
from glob import glob
from zipfile import ZipFile
from tqdm import tqdm
import requests
import logging
from nn_meter.utils import download_from_url
import yaml
from nn_meter.utils import download_from_url, create_user_configs
__user_config_folder__ = os.path.expanduser('~/.nn_meter/config')
def loading_to_local(pred_info, dir="data/predictorzoo"):
@ -63,3 +63,20 @@ def check_predictors(ppath, kernel_predictors):
return True
else:
return False
def load_config_file(fname: str, loader=None):
"""load config file from __user_config_folder__;
if the file not located in __user_config_folder__, copy it from distribution
"""
filepath = os.path.join(__user_config_folder__, fname)
try:
with open(filepath) as fp:
if loader is None:
return yaml.load(fp, yaml.FullLoader)
else:
return loader(fp)
except FileNotFoundError:
logging.info(f"config file {filepath} not found, created")
create_user_configs()
return load_config_file(fname)

Просмотреть файл

@ -3,7 +3,6 @@
from .config_manager import (
create_user_configs,
get_user_data_folder,
change_user_data_folder,
load_config_file
change_user_data_folder
)
from .utils import download_from_url

Просмотреть файл

@ -47,20 +47,3 @@ def change_user_data_folder(new_folder):
with open(os.path.join(__user_config_folder__, 'settings.yaml'), 'w') as fp:
setting['data_folder'] = new_folder
yaml.dump(setting, fp)
def load_config_file(fname: str, loader=None):
"""load config file from __user_config_folder__;
if the file not located in __user_config_folder__, copy it from distribution
"""
filepath = os.path.join(__user_config_folder__, fname)
try:
with open(filepath) as fp:
if loader is None:
return yaml.load(fp, yaml.FullLoader)
else:
return loader(fp)
except FileNotFoundError:
logging.info(f"config file {filepath} not found, created")
create_user_configs()
return load_config_file(fname)