Fix shape in PyTorch Fusion Rule Test (#81)
This commit is contained in:
Родитель
7ac63167f4
Коммит
a700e6d293
|
@ -79,7 +79,10 @@ def generate_testcases():
|
|||
if op1 in d1_required_layers or op2 in d1_required_layers:
|
||||
input_shape = [config['SHAPE_1D']]
|
||||
else:
|
||||
input_shape = [config['HW'], config['HW'], config['CIN']]
|
||||
if implement == "tensorflow":
|
||||
input_shape = [config['HW'], config['HW'], config['CIN']]
|
||||
else:
|
||||
input_shape = [config['CIN'], config['HW'], config['HW']]
|
||||
bf_cls = type(class_name, (BasicFusion,), {
|
||||
'name': name,
|
||||
'cases': cases,
|
||||
|
|
|
@ -6,6 +6,7 @@ import logging
|
|||
from ..utils import read_profiled_results
|
||||
from nn_meter.builder.utils import merge_info
|
||||
from nn_meter.builder.backend_meta.utils import Latency
|
||||
logging = logging.getLogger("nn-Meter")
|
||||
|
||||
|
||||
class BaseTestCase:
|
||||
|
|
|
@ -53,6 +53,6 @@ class FusionRuleTester:
|
|||
latency = {key: str(value) for key, value in rule.latency.items()}
|
||||
result[name]['latency'] = latency
|
||||
|
||||
result[name]['obey'] = obey
|
||||
result[name]['obey'] = bool(obey)
|
||||
|
||||
return result
|
||||
|
|
|
@ -17,8 +17,9 @@ def copy_to_workspace(backend_type, workspace_path, backendConfigFile = None):
|
|||
os.makedirs(os.path.join(workspace_path, 'configs'), exist_ok=True)
|
||||
|
||||
# backend config
|
||||
if backend_type == 'customized' and backendConfigFile:
|
||||
copyfile(backendConfigFile, os.path.join(workspace_path, 'configs', 'backend_config.yaml'))
|
||||
if backend_type == 'customized':
|
||||
if backendConfigFile:
|
||||
copyfile(backendConfigFile, os.path.join(workspace_path, 'configs', 'backend_config.yaml'))
|
||||
else:
|
||||
if backend_type == 'tflite':
|
||||
config_name = __backend_tflite_cfg_filename__
|
||||
|
|
|
@ -128,13 +128,13 @@ class SE(BaseOperator):
|
|||
|
||||
class FC(BaseOperator):
|
||||
def get_model(self):
|
||||
cin = self.input_shape[0]
|
||||
cout = self.input_shape[0] if "COUT" not in self.config else self.config["COUT"]
|
||||
cin = self.input_shape[-1]
|
||||
cout = self.input_shape[-1] if "COUT" not in self.config else self.config["COUT"]
|
||||
return nn.Linear(cin, cout)
|
||||
|
||||
def get_output_shape(self):
|
||||
cout = self.input_shape[0] if "COUT" not in self.config else self.config["COUT"]
|
||||
return [cout] + self.input_shape[1:]
|
||||
cout = self.input_shape[-1] if "COUT" not in self.config else self.config["COUT"]
|
||||
return self.input_shape[:-1] + [cout]
|
||||
|
||||
#-------------------- activation function --------------------#
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче