This commit is contained in:
Sylvain Gugger 2020-07-23 13:51:29 -04:00 коммит произвёл GitHub
Родитель e168488a74
Коммит 6e16195510
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 1 добавлений и 2 удалений

Просмотреть файл

@ -1045,14 +1045,13 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
last_hidden = transformer_outputs[0] last_hidden = transformer_outputs[0]
pred_hid = last_hidden[:, -tgt_len:] pred_hid = last_hidden[:, -tgt_len:]
outputs = transformer_outputs[1:]
softmax_output = self.crit(pred_hid, labels) softmax_output = self.crit(pred_hid, labels)
prediction_scores = softmax_output.view(bsz, tgt_len, -1) if labels is None else () prediction_scores = softmax_output.view(bsz, tgt_len, -1) if labels is None else ()
loss = softmax_output.view(bsz, tgt_len - 1) if labels is not None else None loss = softmax_output.view(bsz, tgt_len - 1) if labels is not None else None
if return_tuple: if return_tuple:
output = (prediction_scores,) + outputs[1:] output = (prediction_scores,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
return TransfoXLLMHeadModelOutput( return TransfoXLLMHeadModelOutput(