Rename question answering evaluation script.

This commit is contained in:
hlums 2019-10-17 17:36:40 +00:00
Родитель 027bff43d8
Коммит 9d25215619
3 изменённых файла: 31 добавление и 31 удаление

Просмотреть файл

@@ -85,7 +85,7 @@
" postprocess_bert_answer\n", " postprocess_bert_answer\n",
")\n", ")\n",
" \n", " \n",
"from utils_nlp.eval.evaluate_question_answering import evaluate_qa\n", "from utils_nlp.eval.question_answering import evaluate_qa\n",
"from utils_nlp.common.timer import Timer" "from utils_nlp.common.timer import Timer"
] ]
}, },
@@ -12067,9 +12067,9 @@
"metadata": { "metadata": {
"celltoolbar": "Tags", "celltoolbar": "Tags",
"kernelspec": { "kernelspec": {
"display_name": "nlp_gpu", "display_name": "nlp_cpu",
"language": "python", "language": "python",
"name": "nlp_gpu" "name": "nlp_cpu"
}, },
"language_info": { "language_info": {
"codemirror_mode": { "codemirror_mode": {

Просмотреть файл

@@ -15,7 +15,7 @@ from utils_nlp.models.transformers.question_answering import (
@pytest.fixture() @pytest.fixture()
def qa_test_data(qa_test_df, tmp_path): def qa_test_data(qa_test_df, tmp):
train_dataset = QADataset( train_dataset = QADataset(
df=qa_test_df["test_df"], df=qa_test_df["test_df"],
@@ -40,7 +40,7 @@ def qa_test_data(qa_test_df, tmp_path):
max_question_length=16, max_question_length=16,
max_seq_length=64, max_seq_length=64,
doc_stride=32, doc_stride=32,
cache_dir=tmp_path, cache_dir=tmp,
) )
test_features_bert = qa_processor_bert.preprocess( test_features_bert = qa_processor_bert.preprocess(
@@ -49,7 +49,7 @@ def qa_test_data(qa_test_df, tmp_path):
max_question_length=16, max_question_length=16,
max_seq_length=64, max_seq_length=64,
doc_stride=32, doc_stride=32,
cache_dir=tmp_path, cache_dir=tmp,
) )
qa_processor_xlnet = QAProcessor(model_name="xlnet-base-cased") qa_processor_xlnet = QAProcessor(model_name="xlnet-base-cased")
@@ -59,7 +59,7 @@ def qa_test_data(qa_test_df, tmp_path):
max_question_length=16, max_question_length=16,
max_seq_length=64, max_seq_length=64,
doc_stride=32, doc_stride=32,
cache_dir=tmp_path, cache_dir=tmp,
) )
test_features_xlnet = qa_processor_xlnet.preprocess( test_features_xlnet = qa_processor_xlnet.preprocess(
@@ -68,7 +68,7 @@ def qa_test_data(qa_test_df, tmp_path):
max_question_length=16, max_question_length=16,
max_seq_length=64, max_seq_length=64,
doc_stride=32, doc_stride=32,
cache_dir=tmp_path, cache_dir=tmp,
) )
qa_processor_distilbert = QAProcessor(model_name="distilbert-base-uncased") qa_processor_distilbert = QAProcessor(model_name="distilbert-base-uncased")
@@ -78,7 +78,7 @@ def qa_test_data(qa_test_df, tmp_path):
max_question_length=16, max_question_length=16,
max_seq_length=64, max_seq_length=64,
doc_stride=32, doc_stride=32,
cache_dir=tmp_path, cache_dir=tmp,
) )
test_features_distilbert = qa_processor_distilbert.preprocess( test_features_distilbert = qa_processor_distilbert.preprocess(
@@ -87,7 +87,7 @@ def qa_test_data(qa_test_df, tmp_path):
max_question_length=16, max_question_length=16,
max_seq_length=64, max_seq_length=64,
doc_stride=32, doc_stride=32,
cache_dir=tmp_path, cache_dir=tmp,
) )
return { return {
@@ -102,7 +102,7 @@ def qa_test_data(qa_test_df, tmp_path):
} }
def test_QAProcessor(qa_test_data, tmp_path): def test_QAProcessor(qa_test_data, tmp):
for model_name in ["bert-base-cased", "xlnet-base-cased", "distilbert-base-uncased"]: for model_name in ["bert-base-cased", "xlnet-base-cased", "distilbert-base-uncased"]:
qa_processor = QAProcessor(model_name=model_name) qa_processor = QAProcessor(model_name=model_name)
qa_processor.preprocess(qa_test_data["train_dataset"], is_training=True) qa_processor.preprocess(qa_test_data["train_dataset"], is_training=True)
@@ -117,31 +117,31 @@ def test_QAProcessor(qa_test_data, tmp_path):
qa_processor.preprocess(qa_test_data["test_dataset"], is_training=True) qa_processor.preprocess(qa_test_data["test_dataset"], is_training=True)
def test_AnswerExtractor(qa_test_data, tmp_path): def test_AnswerExtractor(qa_test_data, tmp):
# test bert # test bert
qa_extractor_bert = AnswerExtractor(cache_dir=tmp_path) qa_extractor_bert = AnswerExtractor(cache_dir=tmp)
qa_extractor_bert.fit( qa_extractor_bert.fit(
qa_test_data["train_features_bert"], cache_model=True, per_gpu_batch_size=8 qa_test_data["train_features_bert"], cache_model=True, per_gpu_batch_size=8
) )
# test saving fine-tuned model # test saving fine-tuned model
model_output_dir = os.path.join(tmp_path, "fine_tuned") model_output_dir = os.path.join(tmp, "fine_tuned")
assert os.path.exists(os.path.join(model_output_dir, "pytorch_model.bin")) assert os.path.exists(os.path.join(model_output_dir, "pytorch_model.bin"))
assert os.path.exists(os.path.join(model_output_dir, "config.json")) assert os.path.exists(os.path.join(model_output_dir, "config.json"))
qa_extractor_from_cache = AnswerExtractor( qa_extractor_from_cache = AnswerExtractor(
cache_dir=tmp_path, load_model_from_dir=model_output_dir cache_dir=tmp, load_model_from_dir=model_output_dir
) )
qa_extractor_from_cache.predict(qa_test_data["test_features_bert"]) qa_extractor_from_cache.predict(qa_test_data["test_features_bert"])
qa_extractor_xlnet = AnswerExtractor(model_name="xlnet-base-cased", cache_dir=tmp_path) qa_extractor_xlnet = AnswerExtractor(model_name="xlnet-base-cased", cache_dir=tmp)
qa_extractor_xlnet.fit( qa_extractor_xlnet.fit(
qa_test_data["train_features_xlnet"], cache_model=False, per_gpu_batch_size=8 qa_test_data["train_features_xlnet"], cache_model=False, per_gpu_batch_size=8
) )
qa_extractor_xlnet.predict(qa_test_data["test_features_xlnet"]) qa_extractor_xlnet.predict(qa_test_data["test_features_xlnet"])
qa_extractor_distilbert = AnswerExtractor( qa_extractor_distilbert = AnswerExtractor(
model_name="distilbert-base-uncased", cache_dir=tmp_path model_name="distilbert-base-uncased", cache_dir=tmp
) )
qa_extractor_distilbert.fit( qa_extractor_distilbert.fit(
qa_test_data["train_features_distilbert"], cache_model=False, per_gpu_batch_size=8 qa_test_data["train_features_distilbert"], cache_model=False, per_gpu_batch_size=8
@@ -149,7 +149,7 @@ def test_AnswerExtractor(qa_test_data, tmp_path):
qa_extractor_distilbert.predict(qa_test_data["test_features_distilbert"]) qa_extractor_distilbert.predict(qa_test_data["test_features_distilbert"])
def test_postprocess_bert_answer(qa_test_data, tmp_path): def test_postprocess_bert_answer(qa_test_data, tmp):
qa_processor = QAProcessor() qa_processor = QAProcessor()
test_features = qa_processor.preprocess( test_features = qa_processor.preprocess(
qa_test_data["test_dataset"], qa_test_data["test_dataset"],
@@ -157,29 +157,29 @@ def test_postprocess_bert_answer(qa_test_data, tmp_path):
max_question_length=16, max_question_length=16,
max_seq_length=64, max_seq_length=64,
doc_stride=32, doc_stride=32,
cache_dir=tmp_path, cache_dir=tmp,
) )
qa_extractor = AnswerExtractor(cache_dir=tmp_path) qa_extractor = AnswerExtractor(cache_dir=tmp)
predictions = qa_extractor.predict(test_features) predictions = qa_extractor.predict(test_features)
postprocess_bert_answer( postprocess_bert_answer(
results=predictions, results=predictions,
examples_file=os.path.join(tmp_path, CACHED_EXAMPLES_TEST_FILE), examples_file=os.path.join(tmp, CACHED_EXAMPLES_TEST_FILE),
features_file=os.path.join(tmp_path, CACHED_FEATURES_TEST_FILE), features_file=os.path.join(tmp, CACHED_FEATURES_TEST_FILE),
do_lower_case=False, do_lower_case=False,
) )
postprocess_bert_answer( postprocess_bert_answer(
results=predictions, results=predictions,
examples_file=os.path.join(tmp_path, CACHED_EXAMPLES_TEST_FILE), examples_file=os.path.join(tmp, CACHED_EXAMPLES_TEST_FILE),
features_file=os.path.join(tmp_path, CACHED_FEATURES_TEST_FILE), features_file=os.path.join(tmp, CACHED_FEATURES_TEST_FILE),
do_lower_case=False, do_lower_case=False,
unanswerable_exists=True, unanswerable_exists=True,
verbose_logging=True, verbose_logging=True,
) )
def test_postprocess_xlnet_answer(qa_test_data, tmp_path): def test_postprocess_xlnet_answer(qa_test_data, tmp):
qa_processor = QAProcessor(model_name="xlnet-base-cased") qa_processor = QAProcessor(model_name="xlnet-base-cased")
test_features = qa_processor.preprocess( test_features = qa_processor.preprocess(
qa_test_data["test_dataset"], qa_test_data["test_dataset"],
@@ -187,24 +187,24 @@ def test_postprocess_xlnet_answer(qa_test_data, tmp_path):
max_question_length=16, max_question_length=16,
max_seq_length=64, max_seq_length=64,
doc_stride=32, doc_stride=32,
cache_dir=tmp_path, cache_dir=tmp,
) )
qa_extractor = AnswerExtractor(model_name="xlnet-base-cased", cache_dir=tmp_path) qa_extractor = AnswerExtractor(model_name="xlnet-base-cased", cache_dir=tmp)
predictions = qa_extractor.predict(test_features) predictions = qa_extractor.predict(test_features)
postprocess_xlnet_answer( postprocess_xlnet_answer(
model_name="xlnet-base-cased", model_name="xlnet-base-cased",
results=predictions, results=predictions,
examples_file=os.path.join(tmp_path, CACHED_EXAMPLES_TEST_FILE), examples_file=os.path.join(tmp, CACHED_EXAMPLES_TEST_FILE),
features_file=os.path.join(tmp_path, CACHED_FEATURES_TEST_FILE), features_file=os.path.join(tmp, CACHED_FEATURES_TEST_FILE),
do_lower_case=False, do_lower_case=False,
) )
postprocess_xlnet_answer( postprocess_xlnet_answer(
model_name="xlnet-base-cased", model_name="xlnet-base-cased",
results=predictions, results=predictions,
examples_file=os.path.join(tmp_path, CACHED_EXAMPLES_TEST_FILE), examples_file=os.path.join(tmp, CACHED_EXAMPLES_TEST_FILE),
features_file=os.path.join(tmp_path, CACHED_FEATURES_TEST_FILE), features_file=os.path.join(tmp, CACHED_FEATURES_TEST_FILE),
do_lower_case=False, do_lower_case=False,
unanswerable_exists=True, unanswerable_exists=True,
verbose_logging=True, verbose_logging=True,