Moved postprocess function into QAProcessor.
This commit is contained in:
Родитель
16424fe60f
Коммит
4677d0b6e4
|
@ -40,7 +40,7 @@ def qa_test_data(qa_test_df, tmp):
|
|||
max_question_length=16,
|
||||
max_seq_length=64,
|
||||
doc_stride=32,
|
||||
cache_dir=tmp,
|
||||
feature_cache_dir=tmp,
|
||||
)
|
||||
|
||||
test_features_bert = qa_processor_bert.preprocess(
|
||||
|
@ -49,7 +49,7 @@ def qa_test_data(qa_test_df, tmp):
|
|||
max_question_length=16,
|
||||
max_seq_length=64,
|
||||
doc_stride=32,
|
||||
cache_dir=tmp,
|
||||
feature_cache_dir=tmp,
|
||||
)
|
||||
|
||||
qa_processor_xlnet = QAProcessor(model_name="xlnet-base-cased")
|
||||
|
@ -59,7 +59,7 @@ def qa_test_data(qa_test_df, tmp):
|
|||
max_question_length=16,
|
||||
max_seq_length=64,
|
||||
doc_stride=32,
|
||||
cache_dir=tmp,
|
||||
feature_cache_dir=tmp,
|
||||
)
|
||||
|
||||
test_features_xlnet = qa_processor_xlnet.preprocess(
|
||||
|
@ -68,7 +68,7 @@ def qa_test_data(qa_test_df, tmp):
|
|||
max_question_length=16,
|
||||
max_seq_length=64,
|
||||
doc_stride=32,
|
||||
cache_dir=tmp,
|
||||
feature_cache_dir=tmp,
|
||||
)
|
||||
|
||||
qa_processor_distilbert = QAProcessor(model_name="distilbert-base-uncased")
|
||||
|
@ -78,7 +78,7 @@ def qa_test_data(qa_test_df, tmp):
|
|||
max_question_length=16,
|
||||
max_seq_length=64,
|
||||
doc_stride=32,
|
||||
cache_dir=tmp,
|
||||
feature_cache_dir=tmp,
|
||||
)
|
||||
|
||||
test_features_distilbert = qa_processor_distilbert.preprocess(
|
||||
|
@ -87,7 +87,7 @@ def qa_test_data(qa_test_df, tmp):
|
|||
max_question_length=16,
|
||||
max_seq_length=64,
|
||||
doc_stride=32,
|
||||
cache_dir=tmp,
|
||||
feature_cache_dir=tmp,
|
||||
)
|
||||
|
||||
return {
|
||||
|
@ -157,23 +157,21 @@ def test_postprocess_bert_answer(qa_test_data, tmp):
|
|||
max_question_length=16,
|
||||
max_seq_length=64,
|
||||
doc_stride=32,
|
||||
cache_dir=tmp,
|
||||
feature_cache_dir=tmp,
|
||||
)
|
||||
qa_extractor = AnswerExtractor(cache_dir=tmp)
|
||||
predictions = qa_extractor.predict(test_features)
|
||||
|
||||
postprocess_bert_answer(
|
||||
qa_processor.postprocess(
|
||||
results=predictions,
|
||||
examples_file=os.path.join(tmp, CACHED_EXAMPLES_TEST_FILE),
|
||||
features_file=os.path.join(tmp, CACHED_FEATURES_TEST_FILE),
|
||||
do_lower_case=False,
|
||||
)
|
||||
|
||||
postprocess_bert_answer(
|
||||
qa_processor.postprocess(
|
||||
results=predictions,
|
||||
examples_file=os.path.join(tmp, CACHED_EXAMPLES_TEST_FILE),
|
||||
features_file=os.path.join(tmp, CACHED_FEATURES_TEST_FILE),
|
||||
do_lower_case=False,
|
||||
unanswerable_exists=True,
|
||||
verbose_logging=True,
|
||||
)
|
||||
|
@ -187,25 +185,21 @@ def test_postprocess_xlnet_answer(qa_test_data, tmp):
|
|||
max_question_length=16,
|
||||
max_seq_length=64,
|
||||
doc_stride=32,
|
||||
cache_dir=tmp,
|
||||
feature_cache_dir=tmp,
|
||||
)
|
||||
qa_extractor = AnswerExtractor(model_name="xlnet-base-cased", cache_dir=tmp)
|
||||
predictions = qa_extractor.predict(test_features)
|
||||
|
||||
postprocess_xlnet_answer(
|
||||
model_name="xlnet-base-cased",
|
||||
qa_processor.postprocess(
|
||||
results=predictions,
|
||||
examples_file=os.path.join(tmp, CACHED_EXAMPLES_TEST_FILE),
|
||||
features_file=os.path.join(tmp, CACHED_FEATURES_TEST_FILE),
|
||||
do_lower_case=False,
|
||||
)
|
||||
|
||||
postprocess_xlnet_answer(
|
||||
model_name="xlnet-base-cased",
|
||||
qa_processor.postprocess(
|
||||
results=predictions,
|
||||
examples_file=os.path.join(tmp, CACHED_EXAMPLES_TEST_FILE),
|
||||
features_file=os.path.join(tmp, CACHED_FEATURES_TEST_FILE),
|
||||
do_lower_case=False,
|
||||
unanswerable_exists=True,
|
||||
verbose_logging=True,
|
||||
)
|
||||
|
|
Загрузка…
Ссылка в новой задаче