Zero shot distillation script cuda patch (#10284)

This commit is contained in:
Joe Davison 2021-02-19 14:06:57 -05:00 коммит произвёл GitHub
Родитель f1299f5038
Коммит cbadb5243c
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 1 добавлений и 1 удалений

Просмотреть файл

@ -174,7 +174,7 @@ def get_teacher_predictions(
model = AutoModelForSequenceClassification.from_pretrained(model_path)
model_config = model.config
if not no_cuda and torch.cuda.is_available():
model = nn.DataParallel(model)
model = nn.DataParallel(model.cuda())
batch_size *= len(model.device_ids)
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=use_fast_tokenizer)