Fix shape in PyTorch Fusion Rule Test (#81)

This commit is contained in:
Jiahang Xu 2022-09-13 18:38:57 +08:00 коммит произвёл GitHub
Родитель 7ac63167f4
Коммит a700e6d293
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
5 изменённых файлов: 13 добавлений и 8 удалений

Просмотреть файл

@ -79,7 +79,10 @@ def generate_testcases():
if op1 in d1_required_layers or op2 in d1_required_layers:
input_shape = [config['SHAPE_1D']]
else:
input_shape = [config['HW'], config['HW'], config['CIN']]
if implement == "tensorflow":
input_shape = [config['HW'], config['HW'], config['CIN']]
else:
input_shape = [config['CIN'], config['HW'], config['HW']]
bf_cls = type(class_name, (BasicFusion,), {
'name': name,
'cases': cases,

Просмотреть файл

@ -6,6 +6,7 @@ import logging
from ..utils import read_profiled_results
from nn_meter.builder.utils import merge_info
from nn_meter.builder.backend_meta.utils import Latency
logging = logging.getLogger("nn-Meter")
class BaseTestCase:

Просмотреть файл

@ -53,6 +53,6 @@ class FusionRuleTester:
latency = {key: str(value) for key, value in rule.latency.items()}
result[name]['latency'] = latency
result[name]['obey'] = obey
result[name]['obey'] = bool(obey)
return result

Просмотреть файл

@ -17,8 +17,9 @@ def copy_to_workspace(backend_type, workspace_path, backendConfigFile = None):
os.makedirs(os.path.join(workspace_path, 'configs'), exist_ok=True)
# backend config
if backend_type == 'customized' and backendConfigFile:
copyfile(backendConfigFile, os.path.join(workspace_path, 'configs', 'backend_config.yaml'))
if backend_type == 'customized':
if backendConfigFile:
copyfile(backendConfigFile, os.path.join(workspace_path, 'configs', 'backend_config.yaml'))
else:
if backend_type == 'tflite':
config_name = __backend_tflite_cfg_filename__

Просмотреть файл

@ -128,13 +128,13 @@ class SE(BaseOperator):
class FC(BaseOperator):
def get_model(self):
cin = self.input_shape[0]
cout = self.input_shape[0] if "COUT" not in self.config else self.config["COUT"]
cin = self.input_shape[-1]
cout = self.input_shape[-1] if "COUT" not in self.config else self.config["COUT"]
return nn.Linear(cin, cout)
def get_output_shape(self):
cout = self.input_shape[0] if "COUT" not in self.config else self.config["COUT"]
return [cout] + self.input_shape[1:]
cout = self.input_shape[-1] if "COUT" not in self.config else self.config["COUT"]
return self.input_shape[:-1] + [cout]
#-------------------- activation function --------------------#