fix benchmark for longformer (#5808)
This commit is contained in:
Parent: aefc0c0429
Commit: 89a78be51f
@@ -88,7 +88,11 @@ class PyTorchBenchmark(Benchmark):
         if self.args.torchscript:
             config.torchscript = True
 
-        has_model_class_in_config = hasattr(config, "architectures") and len(config.architectures) > 0
+        has_model_class_in_config = (
+            hasattr(config, "architectures")
+            and isinstance(config.architectures, list)
+            and len(config.architectures) > 0
+        )
         if not self.args.only_pretrain_model and has_model_class_in_config:
             try:
                 model_class = config.architectures[0]
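
Why the extra guard matters: a config can define the architectures attribute but set it to None, so the old one-liner gets past hasattr() and then calls len(None), raising a TypeError before the benchmark can fall back to the base model. A minimal sketch of the failure and the fix (DummyConfig is a stand-in for a real config object, not code from this commit):

# Minimal sketch, assuming a config whose architectures attribute is None.
class DummyConfig:
    architectures = None

config = DummyConfig()

# Old check: hasattr() is True, so len(None) is evaluated and raises.
try:
    hasattr(config, "architectures") and len(config.architectures) > 0
except TypeError as err:
    print(f"old check crashes: {err}")  # object of type 'NoneType' has no len()

# New check: the isinstance() guard short-circuits before len() is called.
has_model_class_in_config = (
    hasattr(config, "architectures")
    and isinstance(config.architectures, list)
    and len(config.architectures) > 0
)
print(has_model_class_in_config)  # False, so the benchmark takes the fallback path
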
@@ -138,7 +142,11 @@ class PyTorchBenchmark(Benchmark):
     def _prepare_train_func(self, model_name: str, batch_size: int, sequence_length: int) -> Callable[[], None]:
         config = self.config_dict[model_name]
 
-        has_model_class_in_config = hasattr(config, "architectures") and len(config.architectures) > 0
+        has_model_class_in_config = (
+            hasattr(config, "architectures")
+            and isinstance(config.architectures, list)
+            and len(config.architectures) > 0
+        )
         if not self.args.only_pretrain_model and has_model_class_in_config:
             try:
                 model_class = config.architectures[0]
@@ -132,7 +132,11 @@ class TensorFlowBenchmark(Benchmark):
         if self.args.fp16:
             raise NotImplementedError("Mixed precision is currently not supported.")
 
-        has_model_class_in_config = hasattr(config, "architectures") and len(config.architectures) > 0
+        has_model_class_in_config = (
+            hasattr(config, "architectures")
+            and isinstance(config.architectures, list)
+            and len(config.architectures) > 0
+        )
         if not self.args.only_pretrain_model and has_model_class_in_config:
             try:
                 model_class = "TF" + config.architectures[0]  # prepend 'TF' for tensorflow model
@@ -172,7 +176,11 @@ class TensorFlowBenchmark(Benchmark):
         if self.args.fp16:
             raise NotImplementedError("Mixed precision is currently not supported.")
 
-        has_model_class_in_config = hasattr(config, "architectures") and len(config.architectures) > 0
+        has_model_class_in_config = (
+            hasattr(config, "architectures")
+            and isinstance(config.architectures, list)
+            and len(config.architectures) > 0
+        )
         if not self.args.only_pretrain_model and has_model_class_in_config:
             try:
                 model_class = "TF" + config.architectures[0]  # prepend 'TF' for tensorflow model
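
The TensorFlow hunks differ only in the class-name line: TF model classes mirror the PyTorch architecture names with a "TF" prefix, so the benchmark prepends it before resolving the class. A hedged sketch of how that prefixed string can be resolved to a class on the transformers module (the benchmark's actual loading code sits after these hunks and is not shown in the diff; requires TensorFlow to be installed):

# Assumes TF class names are the architecture names prefixed with "TF",
# e.g. "GPT2LMHeadModel" -> "TFGPT2LMHeadModel".
import transformers

config = transformers.AutoConfig.from_pretrained("sshleifer/tiny-gpt2")
model_class = "TF" + config.architectures[0]
model_cls = getattr(transformers, model_class)  # TFGPT2LMHeadModel
model = model_cls(config)
print(type(model).__name__)
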
@@ -86,6 +86,24 @@ class BenchmarkTest(unittest.TestCase):
         self.check_results_dict_not_empty(results.time_inference_result)
         self.check_results_dict_not_empty(results.memory_inference_result)
 
+    def test_inference_no_model_no_architectures(self):
+        MODEL_ID = "sshleifer/tiny-gpt2"
+        config = AutoConfig.from_pretrained(MODEL_ID)
+        # set architectures equal to `None`
+        config.architectures = None
+        benchmark_args = PyTorchBenchmarkArguments(
+            models=[MODEL_ID],
+            training=True,
+            no_inference=False,
+            sequence_lengths=[8],
+            batch_sizes=[1],
+            no_multi_process=True,
+        )
+        benchmark = PyTorchBenchmark(benchmark_args, configs=[config])
+        results = benchmark.run()
+        self.check_results_dict_not_empty(results.time_inference_result)
+        self.check_results_dict_not_empty(results.memory_inference_result)
+
     def test_train_no_configs(self):
         MODEL_ID = "sshleifer/tiny-gpt2"
         benchmark_args = PyTorchBenchmarkArguments(
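
Because the new test passes training=True and no_inference=False, both the inference and train preparation paths patched above run against a config whose architectures is None; the assertions only require that the timing and memory result dicts come back non-empty. To run just this test, an invocation along these lines should work (the tests/test_benchmark.py path is assumed from the repository layout of the time, not stated in the diff):

python -m pytest tests/test_benchmark.py -k no_architectures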