This commit is contained in:
Родитель
e168488a74
Коммит
6e16195510
|
@ -1045,14 +1045,13 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
|
|||
|
||||
last_hidden = transformer_outputs[0]
|
||||
pred_hid = last_hidden[:, -tgt_len:]
|
||||
outputs = transformer_outputs[1:]
|
||||
|
||||
softmax_output = self.crit(pred_hid, labels)
|
||||
prediction_scores = softmax_output.view(bsz, tgt_len, -1) if labels is None else ()
|
||||
loss = softmax_output.view(bsz, tgt_len - 1) if labels is not None else None
|
||||
|
||||
if return_tuple:
|
||||
output = (prediction_scores,) + outputs[1:]
|
||||
output = (prediction_scores,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TransfoXLLMHeadModelOutput(
|
||||
|
|
Загрузка…
Ссылка в новой задаче