Patrick von Platen 2020-07-16 15:15:10 +02:00 committed by GitHub
Parent aefc0c0429
Commit 89a78be51f
No key found corresponding to this signature
GPG key ID: 4AEE18F83AFDEB23
3 changed files: 38 additions and 4 deletions

View file

@@ -88,7 +88,11 @@ class PyTorchBenchmark(Benchmark):
         if self.args.torchscript:
             config.torchscript = True
 
-        has_model_class_in_config = hasattr(config, "architectures") and len(config.architectures) > 0
+        has_model_class_in_config = (
+            hasattr(config, "architectures")
+            and isinstance(config.architectures, list)
+            and len(config.architectures) > 0
+        )
         if not self.args.only_pretrain_model and has_model_class_in_config:
             try:
                 model_class = config.architectures[0]
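
Why the guard changed: when config.architectures is explicitly set to None, hasattr(config, "architectures") is still True, so the old expression goes on to evaluate len(None) and raises a TypeError. A minimal standalone sketch of the new guard (the helper below is hypothetical, not part of the library):

    def has_model_class_in_config(architectures) -> bool:
        # Only treat `architectures` as usable when it is a non-empty list;
        # None now short-circuits at the isinstance() check instead of reaching len().
        return isinstance(architectures, list) and len(architectures) > 0

    print(has_model_class_in_config(["GPT2LMHeadModel"]))  # True
    print(has_model_class_in_config(None))                  # False (old check raised TypeError here)
    print(has_model_class_in_config([]))                    # False

With the guard evaluating to False, the benchmark can fall back to loading the plain pretrained model instead of raising.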
@@ -138,7 +142,11 @@ class PyTorchBenchmark(Benchmark):
     def _prepare_train_func(self, model_name: str, batch_size: int, sequence_length: int) -> Callable[[], None]:
         config = self.config_dict[model_name]
 
-        has_model_class_in_config = hasattr(config, "architectures") and len(config.architectures) > 0
+        has_model_class_in_config = (
+            hasattr(config, "architectures")
+            and isinstance(config.architectures, list)
+            and len(config.architectures) > 0
+        )
         if not self.args.only_pretrain_model and has_model_class_in_config:
             try:
                 model_class = config.architectures[0]

View file

@@ -132,7 +132,11 @@ class TensorFlowBenchmark(Benchmark):
         if self.args.fp16:
             raise NotImplementedError("Mixed precision is currently not supported.")
 
-        has_model_class_in_config = hasattr(config, "architectures") and len(config.architectures) > 0
+        has_model_class_in_config = (
+            hasattr(config, "architectures")
+            and isinstance(config.architectures, list)
+            and len(config.architectures) > 0
+        )
         if not self.args.only_pretrain_model and has_model_class_in_config:
             try:
                 model_class = "TF" + config.architectures[0]  # prepend 'TF' for tensorflow model
@@ -172,7 +176,11 @@ class TensorFlowBenchmark(Benchmark):
         if self.args.fp16:
             raise NotImplementedError("Mixed precision is currently not supported.")
 
-        has_model_class_in_config = hasattr(config, "architectures") and len(config.architectures) > 0
+        has_model_class_in_config = (
+            hasattr(config, "architectures")
+            and isinstance(config.architectures, list)
+            and len(config.architectures) > 0
+        )
         if not self.args.only_pretrain_model and has_model_class_in_config:
             try:
                 model_class = "TF" + config.architectures[0]  # prepend 'TF' for tensorflow model
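
For the TensorFlow benchmark the same guard applies, and the architecture name is prefixed with "TF" before the class is looked up. The surrounding try/except suggests the lookup is done by attribute access on the transformers module; the getattr-based resolution below is an assumption drawn from that, not quoted from the diff:

    import transformers
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("sshleifer/tiny-gpt2")
    if isinstance(config.architectures, list) and len(config.architectures) > 0:
        model_class = "TF" + config.architectures[0]  # e.g. "TFGPT2LMHeadModel"
        try:
            tf_model_cls = getattr(transformers, model_class)  # assumed lookup mechanism
        except AttributeError:
            tf_model_cls = None  # fall back when no TF counterpart exists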

View file

@@ -86,6 +86,24 @@ class BenchmarkTest(unittest.TestCase):
         self.check_results_dict_not_empty(results.time_inference_result)
         self.check_results_dict_not_empty(results.memory_inference_result)
 
+    def test_inference_no_model_no_architecuters(self):
+        MODEL_ID = "sshleifer/tiny-gpt2"
+        config = AutoConfig.from_pretrained(MODEL_ID)
+        # set architectures equal to `None`
+        config.architectures = None
+        benchmark_args = PyTorchBenchmarkArguments(
+            models=[MODEL_ID],
+            training=True,
+            no_inference=False,
+            sequence_lengths=[8],
+            batch_sizes=[1],
+            no_multi_process=True,
+        )
+        benchmark = PyTorchBenchmark(benchmark_args, configs=[config])
+        results = benchmark.run()
+        self.check_results_dict_not_empty(results.time_inference_result)
+        self.check_results_dict_not_empty(results.memory_inference_result)
+
     def test_train_no_configs(self):
         MODEL_ID = "sshleifer/tiny-gpt2"
         benchmark_args = PyTorchBenchmarkArguments(
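
To exercise just the new regression test, something like the following should work; the test module path is an assumption based on the repository layout, not taken from the diff:

    # Run only the new test (path and keyword filter are assumptions)
    import pytest

    pytest.main(["tests/test_benchmark.py", "-k", "test_inference_no_model_no_architecuters", "-v"])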