Support T5 Distillation w/hidden state supervision (#7599)

This commit is contained in:
Sam Shleifer 2020-10-05 21:31:48 -04:00 коммит произвёл GitHub
Родитель 818c294fdd
Коммит d5d2744aa7
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 36 добавлений и 29 удалений

Просмотреть файл

@ -28,7 +28,7 @@ from lightning_base import generic_train # noqa
class BartSummarizationDistiller(SummarizationModule):
"""Supports Bart, Pegasus and other models that inherit from Bart."""
loss_names = ["loss", "ce_loss", "mlm_loss", "enc_mse_loss", "hid_loss_enc", "hid_loss_dec"]
loss_names = ["loss", "ce_loss", "mlm_loss", "hid_loss_enc", "hid_loss_dec"]
def __init__(self, hparams):
assert Path(hparams.data_dir).exists()
@ -46,9 +46,19 @@ class BartSummarizationDistiller(SummarizationModule):
if hparams.length_penalty != -1:
student.config.length_penalty = hparams.length_penalty
super().__init__(hparams, model=student, config=student.config)
model_type = student.config.model_type
self.e_layer_ids, self.d_layer_ids = e_layer_ids, d_layer_ids # type: List[int], List[int]
self.different_encoder = hparams.student_encoder_layers != teacher.config.encoder_layers
self.different_decoder = hparams.student_decoder_layers != teacher.config.decoder_layers
if model_type == "t5":
teacher_encoder_layers = len(teacher.get_encoder().block)
teacher_decoder_layers = len(teacher.get_decoder().block)
else:
teacher_encoder_layers = teacher.config.encoder_layers
teacher_decoder_layers = teacher.config.decoder_layers
self.different_encoder = hparams.student_encoder_layers != teacher_encoder_layers
self.different_decoder = hparams.student_decoder_layers != teacher_decoder_layers
self.teacher = teacher
freeze_params(self.teacher)
@ -59,17 +69,17 @@ class BartSummarizationDistiller(SummarizationModule):
del self.teacher.encoder
# Intermediate supervision: Decide which layers to supervise
if hparams.supervise_forward:
self.d_matches = get_layers_to_supervise(
n_student=len(self.d_layer_ids), n_teacher=self.teacher.config.decoder_layers
)
else:
self.e_matches = get_layers_to_supervise(n_student=len(self.e_layer_ids), n_teacher=teacher_encoder_layers)
self.d_matches = get_layers_to_supervise(n_student=len(self.d_layer_ids), n_teacher=teacher_decoder_layers)
else: # student layer should emulate hidden states of the teacher layer it was copied from
self.e_matches = self.e_layer_ids
self.d_matches = self.d_layer_ids
self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
self.temperature = 2.0
self.alpha_mlm = hparams.alpha_mlm
self.alpha_ce = hparams.alpha_ce
self.alpha_hid = hparams.alpha_hid
self.alpha_encoder_loss = hparams.alpha_encoder_loss
gc.collect()
torch.cuda.empty_cache()
@ -129,7 +139,7 @@ class BartSummarizationDistiller(SummarizationModule):
output_hidden_states=True,
output_attentions=False,
use_cache=False,
) # TODO(@sshleifer): return_dict=True cleanup
)
# Same cross entropy vs. label smoothing logic as finetune.py
assert lm_logits.shape[-1] == self.model.config.vocab_size
@ -146,30 +156,32 @@ class BartSummarizationDistiller(SummarizationModule):
def zero_tensor():
return torch.tensor(0.0).type_as(student_lm_loss)
loss_encoder, hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor(), zero_tensor()
if self.different_encoder:
hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor()
if self.different_encoder: # compute encoder hidden state loss
with torch.no_grad():
teacher_enc_outputs, teacher_enc_hid, _ = self.teacher.get_encoder()(
input_ids, attention_mask=src_mask, output_hidden_states=True
)
# DEPRECATE THIS
if self.hparams.alpha_encoder_loss > 0:
loss_encoder = self.calc_mse_loss(enc_outputs, teacher_enc_outputs, src_mask)
teacher_enc_hid = self.teacher.get_encoder()(
input_ids, attention_mask=src_mask, output_hidden_states=True, return_dict=True
).hidden_states
hid_loss_enc = self.calc_hidden_loss(src_mask, enc_hidden_state, teacher_enc_hid, self.e_layer_ids)
teacher_enc_outputs = (enc_outputs,)
assert isinstance(teacher_enc_outputs, tuple), type(teacher_enc_outputs)
hid_loss_enc = self.calc_hidden_loss(
src_mask,
enc_hidden_state,
teacher_enc_hid,
self.e_matches,
normalize_hidden=self.hparams.normalize_hidden,
)
with torch.no_grad():
tloss, tlogits, tdec_hidden, _ = self.teacher(
outputs = self.teacher(
input_ids,
attention_mask=src_mask,
encoder_outputs=teacher_enc_outputs,
encoder_outputs=(enc_outputs,),
decoder_input_ids=decoder_input_ids,
lm_labels=labels,
output_hidden_states=True,
return_dict=True,
)
tlogits, tdec_hidden = outputs.logits, outputs.decoder_hidden_states
dec_mask = decoder_input_ids.ne(pad_token_id)
loss_ce = self.calc_ce_loss(dec_mask, lm_logits, tlogits)
if self.alpha_hid > 0: # Intermediate supervision of decoder hidden states
@ -180,10 +192,9 @@ class BartSummarizationDistiller(SummarizationModule):
blended_loss = (
self.alpha_ce * loss_ce
+ self.alpha_mlm * student_lm_loss
+ self.hparams.alpha_encoder_loss * loss_encoder
+ self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec)
)
return blended_loss, loss_ce, student_lm_loss, loss_encoder, hid_loss_enc, hid_loss_dec
return blended_loss, loss_ce, student_lm_loss, hid_loss_enc, hid_loss_dec
@staticmethod
def calc_hidden_loss(attention_mask, hidden_states, hidden_states_T, matches, normalize_hidden):
@ -207,7 +218,6 @@ def add_distill_args(parser):
parser.add_argument("--teacher", type=str)
parser.add_argument("--alpha_ce", default=0.8, type=float)
parser.add_argument("--alpha_mlm", default=0.2, type=float)
parser.add_argument("--alpha_encoder_loss", default=0.0, type=float)
parser.add_argument("--alpha_hid", default=0.0, type=float, required=False)
parser.add_argument("--student_decoder_layers", default=12, type=int, required=False)
parser.add_argument("--student_encoder_layers", default=12, type=int, required=False)

Просмотреть файл

@ -86,7 +86,6 @@ CHEAP_ARGS = {
"n_val": -1,
"n_test": -1,
"student_encoder_layers": 1,
"alpha_encoder_loss": 0.0,
"freeze_encoder": False,
"auto_scale_batch_size": False,
}
@ -230,7 +229,6 @@ class TestSummarizationDistiller(unittest.TestCase):
evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))
@unittest.skip("T5 distillation is broken at the moment")
def test_distill_t5(self):
updates = dict(
student_encoder_layers=1,
@ -255,7 +253,6 @@ class TestSummarizationDistiller(unittest.TestCase):
model_name_or_path="sshleifer/tinier_bart",
teacher=CHEAP_ARGS["model_name_or_path"],
val_check_interval=0.5,
alpha_encoder_loss=0.4,
)
default_updates.update(updates)
args_d: dict = CHEAP_ARGS.copy()