From 4677d0b6e44767e24b2136869ec92152fa2c339d Mon Sep 17 00:00:00 2001 From: hlums Date: Thu, 17 Oct 2019 18:57:05 +0000 Subject: [PATCH] Moved postprocess function into QAProcessor. --- ..._models_transformers_question_answering.py | 30 ++++++++----------- 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/tests/unit/test_models_transformers_question_answering.py b/tests/unit/test_models_transformers_question_answering.py index 51a63b5..074b28a 100644 --- a/tests/unit/test_models_transformers_question_answering.py +++ b/tests/unit/test_models_transformers_question_answering.py @@ -40,7 +40,7 @@ def qa_test_data(qa_test_df, tmp): max_question_length=16, max_seq_length=64, doc_stride=32, - cache_dir=tmp, + feature_cache_dir=tmp, ) test_features_bert = qa_processor_bert.preprocess( @@ -49,7 +49,7 @@ def qa_test_data(qa_test_df, tmp): max_question_length=16, max_seq_length=64, doc_stride=32, - cache_dir=tmp, + feature_cache_dir=tmp, ) qa_processor_xlnet = QAProcessor(model_name="xlnet-base-cased") @@ -59,7 +59,7 @@ def qa_test_data(qa_test_df, tmp): max_question_length=16, max_seq_length=64, doc_stride=32, - cache_dir=tmp, + feature_cache_dir=tmp, ) test_features_xlnet = qa_processor_xlnet.preprocess( @@ -68,7 +68,7 @@ def qa_test_data(qa_test_df, tmp): max_question_length=16, max_seq_length=64, doc_stride=32, - cache_dir=tmp, + feature_cache_dir=tmp, ) qa_processor_distilbert = QAProcessor(model_name="distilbert-base-uncased") @@ -78,7 +78,7 @@ def qa_test_data(qa_test_df, tmp): max_question_length=16, max_seq_length=64, doc_stride=32, - cache_dir=tmp, + feature_cache_dir=tmp, ) test_features_distilbert = qa_processor_distilbert.preprocess( @@ -87,7 +87,7 @@ def qa_test_data(qa_test_df, tmp): max_question_length=16, max_seq_length=64, doc_stride=32, - cache_dir=tmp, + feature_cache_dir=tmp, ) return { @@ -157,23 +157,21 @@ def test_postprocess_bert_answer(qa_test_data, tmp): max_question_length=16, max_seq_length=64, doc_stride=32, - cache_dir=tmp, + feature_cache_dir=tmp, ) qa_extractor = AnswerExtractor(cache_dir=tmp) predictions = qa_extractor.predict(test_features) - postprocess_bert_answer( + qa_processor.postprocess( results=predictions, examples_file=os.path.join(tmp, CACHED_EXAMPLES_TEST_FILE), features_file=os.path.join(tmp, CACHED_FEATURES_TEST_FILE), - do_lower_case=False, ) - postprocess_bert_answer( + qa_processor.postprocess( results=predictions, examples_file=os.path.join(tmp, CACHED_EXAMPLES_TEST_FILE), features_file=os.path.join(tmp, CACHED_FEATURES_TEST_FILE), - do_lower_case=False, unanswerable_exists=True, verbose_logging=True, ) @@ -187,25 +185,21 @@ def test_postprocess_xlnet_answer(qa_test_data, tmp): max_question_length=16, max_seq_length=64, doc_stride=32, - cache_dir=tmp, + feature_cache_dir=tmp, ) qa_extractor = AnswerExtractor(model_name="xlnet-base-cased", cache_dir=tmp) predictions = qa_extractor.predict(test_features) - postprocess_xlnet_answer( - model_name="xlnet-base-cased", + qa_processor.postprocess( results=predictions, examples_file=os.path.join(tmp, CACHED_EXAMPLES_TEST_FILE), features_file=os.path.join(tmp, CACHED_FEATURES_TEST_FILE), - do_lower_case=False, ) - postprocess_xlnet_answer( - model_name="xlnet-base-cased", + qa_processor.postprocess( results=predictions, examples_file=os.path.join(tmp, CACHED_EXAMPLES_TEST_FILE), features_file=os.path.join(tmp, CACHED_FEATURES_TEST_FILE), - do_lower_case=False, unanswerable_exists=True, verbose_logging=True, )