Fix two bugs: 1. Index of test data of SST-2. 2. Label index of MNLI data. (#4546)

This commit is contained in:
Zhangyx 2020-05-29 23:12:24 +08:00 коммит произвёл GitHub
Родитель 9c17256447
Коммит 3a5d1ea2a5
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 13 добавлений и 10 удалений

Просмотреть файл

@ -86,6 +86,15 @@ class GlueDataset(Dataset):
mode.value, tokenizer.__class__.__name__, str(args.max_seq_length), args.task_name,
),
)
label_list = self.processor.get_labels()
if args.task_name in ["mnli", "mnli-mm"] and tokenizer.__class__ in (
RobertaTokenizer,
RobertaTokenizerFast,
XLMRobertaTokenizer,
):
# HACK(label indices are swapped in RoBERTa pretrained model)
label_list[1], label_list[2] = label_list[2], label_list[1]
self.label_list = label_list
# Make sure only the first process in distributed training processes the dataset,
# and the others will use the cache.
@ -100,14 +109,7 @@ class GlueDataset(Dataset):
)
else:
logger.info(f"Creating features from dataset file at {args.data_dir}")
label_list = self.processor.get_labels()
if args.task_name in ["mnli", "mnli-mm"] and tokenizer.__class__ in (
RobertaTokenizer,
RobertaTokenizerFast,
XLMRobertaTokenizer,
):
# HACK(label indices are swapped in RoBERTa pretrained model)
label_list[1], label_list[2] = label_list[2], label_list[1]
if mode == Split.dev:
examples = self.processor.get_dev_examples(args.data_dir)
elif mode == Split.test:
@ -137,4 +139,4 @@ class GlueDataset(Dataset):
return self.features[i]
def get_labels(self):
return self.processor.get_labels()
return self.label_list

Просмотреть файл

@ -332,11 +332,12 @@ class Sst2Processor(DataProcessor):
def _create_examples(self, lines, set_type):
"""Creates examples for the training, dev and test sets."""
examples = []
text_index = 1 if set_type == "test" else 0
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, i)
text_a = line[0]
text_a = line[text_index]
label = None if set_type == "test" else line[1]
examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples