[bart-mnli] Fix class flipping bug (#5141)
This commit is contained in:
Родитель
e33929ef1e
Коммит
f45e873910
|
@ -9,6 +9,7 @@ import torch
|
||||||
from filelock import FileLock
|
from filelock import FileLock
|
||||||
from torch.utils.data.dataset import Dataset
|
from torch.utils.data.dataset import Dataset
|
||||||
|
|
||||||
|
from ...tokenization_bart import BartTokenizer, BartTokenizerFast
|
||||||
from ...tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
|
from ...tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
|
||||||
from ...tokenization_utils import PreTrainedTokenizer
|
from ...tokenization_utils import PreTrainedTokenizer
|
||||||
from ...tokenization_xlm_roberta import XLMRobertaTokenizer
|
from ...tokenization_xlm_roberta import XLMRobertaTokenizer
|
||||||
|
@ -92,6 +93,8 @@ class GlueDataset(Dataset):
|
||||||
RobertaTokenizer,
|
RobertaTokenizer,
|
||||||
RobertaTokenizerFast,
|
RobertaTokenizerFast,
|
||||||
XLMRobertaTokenizer,
|
XLMRobertaTokenizer,
|
||||||
|
BartTokenizer,
|
||||||
|
BartTokenizerFast,
|
||||||
):
|
):
|
||||||
# HACK(label indices are swapped in RoBERTa pretrained model)
|
# HACK(label indices are swapped in RoBERTa pretrained model)
|
||||||
label_list[1], label_list[2] = label_list[2], label_list[1]
|
label_list[1], label_list[2] = label_list[2], label_list[1]
|
||||||
|
|
Загрузка…
Ссылка в новой задаче