Rename question answering evaluation script.
This commit is contained in:
Родитель
027bff43d8
Коммит
9d25215619
|
@ -85,7 +85,7 @@
|
||||||
" postprocess_bert_answer\n",
|
" postprocess_bert_answer\n",
|
||||||
")\n",
|
")\n",
|
||||||
" \n",
|
" \n",
|
||||||
"from utils_nlp.eval.evaluate_question_answering import evaluate_qa\n",
|
"from utils_nlp.eval.question_answering import evaluate_qa\n",
|
||||||
"from utils_nlp.common.timer import Timer"
|
"from utils_nlp.common.timer import Timer"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -12067,9 +12067,9 @@
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"celltoolbar": "Tags",
|
"celltoolbar": "Tags",
|
||||||
"kernelspec": {
|
"kernelspec": {
|
||||||
"display_name": "nlp_gpu",
|
"display_name": "nlp_cpu",
|
||||||
"language": "python",
|
"language": "python",
|
||||||
"name": "nlp_gpu"
|
"name": "nlp_cpu"
|
||||||
},
|
},
|
||||||
"language_info": {
|
"language_info": {
|
||||||
"codemirror_mode": {
|
"codemirror_mode": {
|
||||||
|
|
|
@ -15,7 +15,7 @@ from utils_nlp.models.transformers.question_answering import (
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def qa_test_data(qa_test_df, tmp_path):
|
def qa_test_data(qa_test_df, tmp):
|
||||||
|
|
||||||
train_dataset = QADataset(
|
train_dataset = QADataset(
|
||||||
df=qa_test_df["test_df"],
|
df=qa_test_df["test_df"],
|
||||||
|
@ -40,7 +40,7 @@ def qa_test_data(qa_test_df, tmp_path):
|
||||||
max_question_length=16,
|
max_question_length=16,
|
||||||
max_seq_length=64,
|
max_seq_length=64,
|
||||||
doc_stride=32,
|
doc_stride=32,
|
||||||
cache_dir=tmp_path,
|
cache_dir=tmp,
|
||||||
)
|
)
|
||||||
|
|
||||||
test_features_bert = qa_processor_bert.preprocess(
|
test_features_bert = qa_processor_bert.preprocess(
|
||||||
|
@ -49,7 +49,7 @@ def qa_test_data(qa_test_df, tmp_path):
|
||||||
max_question_length=16,
|
max_question_length=16,
|
||||||
max_seq_length=64,
|
max_seq_length=64,
|
||||||
doc_stride=32,
|
doc_stride=32,
|
||||||
cache_dir=tmp_path,
|
cache_dir=tmp,
|
||||||
)
|
)
|
||||||
|
|
||||||
qa_processor_xlnet = QAProcessor(model_name="xlnet-base-cased")
|
qa_processor_xlnet = QAProcessor(model_name="xlnet-base-cased")
|
||||||
|
@ -59,7 +59,7 @@ def qa_test_data(qa_test_df, tmp_path):
|
||||||
max_question_length=16,
|
max_question_length=16,
|
||||||
max_seq_length=64,
|
max_seq_length=64,
|
||||||
doc_stride=32,
|
doc_stride=32,
|
||||||
cache_dir=tmp_path,
|
cache_dir=tmp,
|
||||||
)
|
)
|
||||||
|
|
||||||
test_features_xlnet = qa_processor_xlnet.preprocess(
|
test_features_xlnet = qa_processor_xlnet.preprocess(
|
||||||
|
@ -68,7 +68,7 @@ def qa_test_data(qa_test_df, tmp_path):
|
||||||
max_question_length=16,
|
max_question_length=16,
|
||||||
max_seq_length=64,
|
max_seq_length=64,
|
||||||
doc_stride=32,
|
doc_stride=32,
|
||||||
cache_dir=tmp_path,
|
cache_dir=tmp,
|
||||||
)
|
)
|
||||||
|
|
||||||
qa_processor_distilbert = QAProcessor(model_name="distilbert-base-uncased")
|
qa_processor_distilbert = QAProcessor(model_name="distilbert-base-uncased")
|
||||||
|
@ -78,7 +78,7 @@ def qa_test_data(qa_test_df, tmp_path):
|
||||||
max_question_length=16,
|
max_question_length=16,
|
||||||
max_seq_length=64,
|
max_seq_length=64,
|
||||||
doc_stride=32,
|
doc_stride=32,
|
||||||
cache_dir=tmp_path,
|
cache_dir=tmp,
|
||||||
)
|
)
|
||||||
|
|
||||||
test_features_distilbert = qa_processor_distilbert.preprocess(
|
test_features_distilbert = qa_processor_distilbert.preprocess(
|
||||||
|
@ -87,7 +87,7 @@ def qa_test_data(qa_test_df, tmp_path):
|
||||||
max_question_length=16,
|
max_question_length=16,
|
||||||
max_seq_length=64,
|
max_seq_length=64,
|
||||||
doc_stride=32,
|
doc_stride=32,
|
||||||
cache_dir=tmp_path,
|
cache_dir=tmp,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
@ -102,7 +102,7 @@ def qa_test_data(qa_test_df, tmp_path):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_QAProcessor(qa_test_data, tmp_path):
|
def test_QAProcessor(qa_test_data, tmp):
|
||||||
for model_name in ["bert-base-cased", "xlnet-base-cased", "distilbert-base-uncased"]:
|
for model_name in ["bert-base-cased", "xlnet-base-cased", "distilbert-base-uncased"]:
|
||||||
qa_processor = QAProcessor(model_name=model_name)
|
qa_processor = QAProcessor(model_name=model_name)
|
||||||
qa_processor.preprocess(qa_test_data["train_dataset"], is_training=True)
|
qa_processor.preprocess(qa_test_data["train_dataset"], is_training=True)
|
||||||
|
@ -117,31 +117,31 @@ def test_QAProcessor(qa_test_data, tmp_path):
|
||||||
qa_processor.preprocess(qa_test_data["test_dataset"], is_training=True)
|
qa_processor.preprocess(qa_test_data["test_dataset"], is_training=True)
|
||||||
|
|
||||||
|
|
||||||
def test_AnswerExtractor(qa_test_data, tmp_path):
|
def test_AnswerExtractor(qa_test_data, tmp):
|
||||||
# test bert
|
# test bert
|
||||||
qa_extractor_bert = AnswerExtractor(cache_dir=tmp_path)
|
qa_extractor_bert = AnswerExtractor(cache_dir=tmp)
|
||||||
qa_extractor_bert.fit(
|
qa_extractor_bert.fit(
|
||||||
qa_test_data["train_features_bert"], cache_model=True, per_gpu_batch_size=8
|
qa_test_data["train_features_bert"], cache_model=True, per_gpu_batch_size=8
|
||||||
)
|
)
|
||||||
|
|
||||||
# test saving fine-tuned model
|
# test saving fine-tuned model
|
||||||
model_output_dir = os.path.join(tmp_path, "fine_tuned")
|
model_output_dir = os.path.join(tmp, "fine_tuned")
|
||||||
assert os.path.exists(os.path.join(model_output_dir, "pytorch_model.bin"))
|
assert os.path.exists(os.path.join(model_output_dir, "pytorch_model.bin"))
|
||||||
assert os.path.exists(os.path.join(model_output_dir, "config.json"))
|
assert os.path.exists(os.path.join(model_output_dir, "config.json"))
|
||||||
|
|
||||||
qa_extractor_from_cache = AnswerExtractor(
|
qa_extractor_from_cache = AnswerExtractor(
|
||||||
cache_dir=tmp_path, load_model_from_dir=model_output_dir
|
cache_dir=tmp, load_model_from_dir=model_output_dir
|
||||||
)
|
)
|
||||||
qa_extractor_from_cache.predict(qa_test_data["test_features_bert"])
|
qa_extractor_from_cache.predict(qa_test_data["test_features_bert"])
|
||||||
|
|
||||||
qa_extractor_xlnet = AnswerExtractor(model_name="xlnet-base-cased", cache_dir=tmp_path)
|
qa_extractor_xlnet = AnswerExtractor(model_name="xlnet-base-cased", cache_dir=tmp)
|
||||||
qa_extractor_xlnet.fit(
|
qa_extractor_xlnet.fit(
|
||||||
qa_test_data["train_features_xlnet"], cache_model=False, per_gpu_batch_size=8
|
qa_test_data["train_features_xlnet"], cache_model=False, per_gpu_batch_size=8
|
||||||
)
|
)
|
||||||
qa_extractor_xlnet.predict(qa_test_data["test_features_xlnet"])
|
qa_extractor_xlnet.predict(qa_test_data["test_features_xlnet"])
|
||||||
|
|
||||||
qa_extractor_distilbert = AnswerExtractor(
|
qa_extractor_distilbert = AnswerExtractor(
|
||||||
model_name="distilbert-base-uncased", cache_dir=tmp_path
|
model_name="distilbert-base-uncased", cache_dir=tmp
|
||||||
)
|
)
|
||||||
qa_extractor_distilbert.fit(
|
qa_extractor_distilbert.fit(
|
||||||
qa_test_data["train_features_distilbert"], cache_model=False, per_gpu_batch_size=8
|
qa_test_data["train_features_distilbert"], cache_model=False, per_gpu_batch_size=8
|
||||||
|
@ -149,7 +149,7 @@ def test_AnswerExtractor(qa_test_data, tmp_path):
|
||||||
qa_extractor_distilbert.predict(qa_test_data["test_features_distilbert"])
|
qa_extractor_distilbert.predict(qa_test_data["test_features_distilbert"])
|
||||||
|
|
||||||
|
|
||||||
def test_postprocess_bert_answer(qa_test_data, tmp_path):
|
def test_postprocess_bert_answer(qa_test_data, tmp):
|
||||||
qa_processor = QAProcessor()
|
qa_processor = QAProcessor()
|
||||||
test_features = qa_processor.preprocess(
|
test_features = qa_processor.preprocess(
|
||||||
qa_test_data["test_dataset"],
|
qa_test_data["test_dataset"],
|
||||||
|
@ -157,29 +157,29 @@ def test_postprocess_bert_answer(qa_test_data, tmp_path):
|
||||||
max_question_length=16,
|
max_question_length=16,
|
||||||
max_seq_length=64,
|
max_seq_length=64,
|
||||||
doc_stride=32,
|
doc_stride=32,
|
||||||
cache_dir=tmp_path,
|
cache_dir=tmp,
|
||||||
)
|
)
|
||||||
qa_extractor = AnswerExtractor(cache_dir=tmp_path)
|
qa_extractor = AnswerExtractor(cache_dir=tmp)
|
||||||
predictions = qa_extractor.predict(test_features)
|
predictions = qa_extractor.predict(test_features)
|
||||||
|
|
||||||
postprocess_bert_answer(
|
postprocess_bert_answer(
|
||||||
results=predictions,
|
results=predictions,
|
||||||
examples_file=os.path.join(tmp_path, CACHED_EXAMPLES_TEST_FILE),
|
examples_file=os.path.join(tmp, CACHED_EXAMPLES_TEST_FILE),
|
||||||
features_file=os.path.join(tmp_path, CACHED_FEATURES_TEST_FILE),
|
features_file=os.path.join(tmp, CACHED_FEATURES_TEST_FILE),
|
||||||
do_lower_case=False,
|
do_lower_case=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
postprocess_bert_answer(
|
postprocess_bert_answer(
|
||||||
results=predictions,
|
results=predictions,
|
||||||
examples_file=os.path.join(tmp_path, CACHED_EXAMPLES_TEST_FILE),
|
examples_file=os.path.join(tmp, CACHED_EXAMPLES_TEST_FILE),
|
||||||
features_file=os.path.join(tmp_path, CACHED_FEATURES_TEST_FILE),
|
features_file=os.path.join(tmp, CACHED_FEATURES_TEST_FILE),
|
||||||
do_lower_case=False,
|
do_lower_case=False,
|
||||||
unanswerable_exists=True,
|
unanswerable_exists=True,
|
||||||
verbose_logging=True,
|
verbose_logging=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_postprocess_xlnet_answer(qa_test_data, tmp_path):
|
def test_postprocess_xlnet_answer(qa_test_data, tmp):
|
||||||
qa_processor = QAProcessor(model_name="xlnet-base-cased")
|
qa_processor = QAProcessor(model_name="xlnet-base-cased")
|
||||||
test_features = qa_processor.preprocess(
|
test_features = qa_processor.preprocess(
|
||||||
qa_test_data["test_dataset"],
|
qa_test_data["test_dataset"],
|
||||||
|
@ -187,24 +187,24 @@ def test_postprocess_xlnet_answer(qa_test_data, tmp_path):
|
||||||
max_question_length=16,
|
max_question_length=16,
|
||||||
max_seq_length=64,
|
max_seq_length=64,
|
||||||
doc_stride=32,
|
doc_stride=32,
|
||||||
cache_dir=tmp_path,
|
cache_dir=tmp,
|
||||||
)
|
)
|
||||||
qa_extractor = AnswerExtractor(model_name="xlnet-base-cased", cache_dir=tmp_path)
|
qa_extractor = AnswerExtractor(model_name="xlnet-base-cased", cache_dir=tmp)
|
||||||
predictions = qa_extractor.predict(test_features)
|
predictions = qa_extractor.predict(test_features)
|
||||||
|
|
||||||
postprocess_xlnet_answer(
|
postprocess_xlnet_answer(
|
||||||
model_name="xlnet-base-cased",
|
model_name="xlnet-base-cased",
|
||||||
results=predictions,
|
results=predictions,
|
||||||
examples_file=os.path.join(tmp_path, CACHED_EXAMPLES_TEST_FILE),
|
examples_file=os.path.join(tmp, CACHED_EXAMPLES_TEST_FILE),
|
||||||
features_file=os.path.join(tmp_path, CACHED_FEATURES_TEST_FILE),
|
features_file=os.path.join(tmp, CACHED_FEATURES_TEST_FILE),
|
||||||
do_lower_case=False,
|
do_lower_case=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
postprocess_xlnet_answer(
|
postprocess_xlnet_answer(
|
||||||
model_name="xlnet-base-cased",
|
model_name="xlnet-base-cased",
|
||||||
results=predictions,
|
results=predictions,
|
||||||
examples_file=os.path.join(tmp_path, CACHED_EXAMPLES_TEST_FILE),
|
examples_file=os.path.join(tmp, CACHED_EXAMPLES_TEST_FILE),
|
||||||
features_file=os.path.join(tmp_path, CACHED_FEATURES_TEST_FILE),
|
features_file=os.path.join(tmp, CACHED_FEATURES_TEST_FILE),
|
||||||
do_lower_case=False,
|
do_lower_case=False,
|
||||||
unanswerable_exists=True,
|
unanswerable_exists=True,
|
||||||
verbose_logging=True,
|
verbose_logging=True,
|
||||||
|
|
Загрузка…
Ссылка в новой задаче