From 0533cf470659b97c6279bd04f65536a1ec88404a Mon Sep 17 00:00:00 2001
From: Lysandre Debut
Date: Thu, 9 Jul 2020 09:19:19 -0400
Subject: [PATCH] Test XLA examples (#5583)

* Test XLA examples

* Style

* Using `require_torch_tpu`

* Style

* No need for pytest
---
 examples/test_xla_examples.py     | 91 +++++++++++++++++++++++++++++++
 src/transformers/testing_utils.py | 12 +++-
 2 files changed, 102 insertions(+), 1 deletion(-)
 create mode 100644 examples/test_xla_examples.py

diff --git a/examples/test_xla_examples.py b/examples/test_xla_examples.py
new file mode 100644
index 000000000..c192a87e8
--- /dev/null
+++ b/examples/test_xla_examples.py
@@ -0,0 +1,91 @@
+# coding=utf-8
+# Copyright 2018 HuggingFace Inc..
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+import sys
+import unittest
+from time import time
+from unittest.mock import patch
+
+from transformers.testing_utils import require_torch_tpu
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+
+
+def get_setup_file():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-f")
+    args = parser.parse_args()
+    return args.f
+
+
+@require_torch_tpu
+class TorchXLAExamplesTests(unittest.TestCase):
+    def test_run_glue(self):
+        import xla_spawn
+
+        stream_handler = logging.StreamHandler(sys.stdout)
+        logger.addHandler(stream_handler)
+
+        output_directory = "run_glue_output"
+
+        testargs = f"""
+            text-classification/run_glue.py
+            --num_cores=8
+            text-classification/run_glue.py
+            --do_train
+            --do_eval
+            --task_name=MRPC
+            --data_dir=../glue_data/MRPC
+            --cache_dir=./cache_dir
+            --num_train_epochs=1
+            --max_seq_length=128
+            --learning_rate=3e-5
+            --output_dir={output_directory}
+            --overwrite_output_dir
+            --logging_steps=5
+            --save_steps=5
+            --overwrite_cache
+            --tpu_metrics_debug
+            --model_name_or_path=bert-base-cased
+            --per_device_train_batch_size=64
+            --per_device_eval_batch_size=64
+            --evaluate_during_training
+            --overwrite_cache
+        """.split()
+        with patch.object(sys, "argv", testargs):
+            start = time()
+            xla_spawn.main()
+            end = time()
+
+            result = {}
+            with open(f"{output_directory}/eval_results_mrpc.txt") as f:
+                lines = f.readlines()
+                for line in lines:
+                    key, value = line.split(" = ")
+                    result[key] = float(value)
+
+            del result["eval_loss"]
+            for value in result.values():
+                # Assert that the model trains
+                self.assertGreaterEqual(value, 0.70)
+
+            # Assert that the script takes less than 100 seconds to make sure it doesn't hang.
+            self.assertLess(end - start, 100)
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index 556e3fbff..5c5c9dec4 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -2,7 +2,7 @@ import os
 import unittest
 from distutils.util import strtobool
 
-from transformers.file_utils import _tf_available, _torch_available
+from transformers.file_utils import _tf_available, _torch_available, _torch_tpu_available
 
 
 SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
@@ -113,6 +113,16 @@ def require_multigpu(test_case):
     return test_case
 
 
+def require_torch_tpu(test_case):
+    """
+    Decorator marking a test that requires a TPU (in PyTorch).
+    """
+    if not _torch_tpu_available:
+        return unittest.skip("test requires PyTorch TPU")
+
+    return test_case
+
+
 if _torch_available:
     # Set the USE_CUDA environment variable to select a GPU.
     torch_device = "cuda" if parse_flag_from_env("USE_CUDA") else "cpu"
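
For context, `require_torch_tpu` relies on a `_torch_tpu_available` flag imported from `transformers.file_utils`, whose definition is not part of this diff. The sketch below is a rough, hypothetical illustration rather than the library's actual code: `_torch_tpu_available_sketch` and `TpuSmokeTest` are invented names, and TPU availability is approximated as "the `torch_xla` package is importable". It shows the decorator applied to a test class the same way `TorchXLAExamplesTests` uses it above.

import importlib.util
import unittest

from transformers.testing_utils import require_torch_tpu

# Hypothetical stand-in for the `_torch_tpu_available` flag that the diff
# imports from `transformers.file_utils`: True only when the torch_xla
# (PyTorch/XLA) package can be found in the current environment.
_torch_tpu_available_sketch = importlib.util.find_spec("torch_xla") is not None


@require_torch_tpu
class TpuSmokeTest(unittest.TestCase):
    # Without torch_xla installed, `require_torch_tpu` returns
    # `unittest.skip("test requires PyTorch TPU")` in place of the class,
    # so none of its tests are collected; with a TPU setup they run normally.
    def test_placeholder(self):
        self.assertTrue(True)

Gating at the class level like this keeps TPU-only tests out of CPU/GPU CI runs without modifying the individual test methods.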