diff --git a/src/transformers/data/datasets/glue.py b/src/transformers/data/datasets/glue.py index 1775b93f9..f696f6824 100644 --- a/src/transformers/data/datasets/glue.py +++ b/src/transformers/data/datasets/glue.py @@ -9,6 +9,7 @@ import torch from filelock import FileLock from torch.utils.data.dataset import Dataset +from ...tokenization_bart import BartTokenizer, BartTokenizerFast from ...tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast from ...tokenization_utils import PreTrainedTokenizer from ...tokenization_xlm_roberta import XLMRobertaTokenizer @@ -92,6 +93,8 @@ class GlueDataset(Dataset): RobertaTokenizer, RobertaTokenizerFast, XLMRobertaTokenizer, + BartTokenizer, + BartTokenizerFast, ): # HACK(label indices are swapped in RoBERTa pretrained model) label_list[1], label_list[2] = label_list[2], label_list[1]