Support T5 Distillation w/hidden state supervision (#7599)
This commit is contained in:
Родитель
818c294fdd
Коммит
d5d2744aa7
|
@ -28,7 +28,7 @@ from lightning_base import generic_train # noqa
|
|||
class BartSummarizationDistiller(SummarizationModule):
|
||||
"""Supports Bart, Pegasus and other models that inherit from Bart."""
|
||||
|
||||
loss_names = ["loss", "ce_loss", "mlm_loss", "enc_mse_loss", "hid_loss_enc", "hid_loss_dec"]
|
||||
loss_names = ["loss", "ce_loss", "mlm_loss", "hid_loss_enc", "hid_loss_dec"]
|
||||
|
||||
def __init__(self, hparams):
|
||||
assert Path(hparams.data_dir).exists()
|
||||
|
@ -46,9 +46,19 @@ class BartSummarizationDistiller(SummarizationModule):
|
|||
if hparams.length_penalty != -1:
|
||||
student.config.length_penalty = hparams.length_penalty
|
||||
super().__init__(hparams, model=student, config=student.config)
|
||||
model_type = student.config.model_type
|
||||
self.e_layer_ids, self.d_layer_ids = e_layer_ids, d_layer_ids # type: List[int], List[int]
|
||||
self.different_encoder = hparams.student_encoder_layers != teacher.config.encoder_layers
|
||||
self.different_decoder = hparams.student_decoder_layers != teacher.config.decoder_layers
|
||||
|
||||
if model_type == "t5":
|
||||
teacher_encoder_layers = len(teacher.get_encoder().block)
|
||||
teacher_decoder_layers = len(teacher.get_decoder().block)
|
||||
else:
|
||||
teacher_encoder_layers = teacher.config.encoder_layers
|
||||
teacher_decoder_layers = teacher.config.decoder_layers
|
||||
|
||||
self.different_encoder = hparams.student_encoder_layers != teacher_encoder_layers
|
||||
self.different_decoder = hparams.student_decoder_layers != teacher_decoder_layers
|
||||
|
||||
self.teacher = teacher
|
||||
freeze_params(self.teacher)
|
||||
|
||||
|
@ -59,17 +69,17 @@ class BartSummarizationDistiller(SummarizationModule):
|
|||
del self.teacher.encoder
|
||||
# Intermediate supervision: Decide which layers to supervise
|
||||
if hparams.supervise_forward:
|
||||
self.d_matches = get_layers_to_supervise(
|
||||
n_student=len(self.d_layer_ids), n_teacher=self.teacher.config.decoder_layers
|
||||
)
|
||||
else:
|
||||
self.e_matches = get_layers_to_supervise(n_student=len(self.e_layer_ids), n_teacher=teacher_encoder_layers)
|
||||
self.d_matches = get_layers_to_supervise(n_student=len(self.d_layer_ids), n_teacher=teacher_decoder_layers)
|
||||
else: # student layer should emulate hidden states of the teacher layer it was copied from
|
||||
self.e_matches = self.e_layer_ids
|
||||
self.d_matches = self.d_layer_ids
|
||||
|
||||
self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
|
||||
self.temperature = 2.0
|
||||
self.alpha_mlm = hparams.alpha_mlm
|
||||
self.alpha_ce = hparams.alpha_ce
|
||||
self.alpha_hid = hparams.alpha_hid
|
||||
self.alpha_encoder_loss = hparams.alpha_encoder_loss
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
@ -129,7 +139,7 @@ class BartSummarizationDistiller(SummarizationModule):
|
|||
output_hidden_states=True,
|
||||
output_attentions=False,
|
||||
use_cache=False,
|
||||
) # TODO(@sshleifer): return_dict=True cleanup
|
||||
)
|
||||
|
||||
# Same cross entropy vs. label smoothing logic as finetune.py
|
||||
assert lm_logits.shape[-1] == self.model.config.vocab_size
|
||||
|
@ -146,30 +156,32 @@ class BartSummarizationDistiller(SummarizationModule):
|
|||
def zero_tensor():
|
||||
return torch.tensor(0.0).type_as(student_lm_loss)
|
||||
|
||||
loss_encoder, hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor(), zero_tensor()
|
||||
if self.different_encoder:
|
||||
hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor()
|
||||
if self.different_encoder: # compute encoder hidden state loss
|
||||
with torch.no_grad():
|
||||
teacher_enc_outputs, teacher_enc_hid, _ = self.teacher.get_encoder()(
|
||||
input_ids, attention_mask=src_mask, output_hidden_states=True
|
||||
)
|
||||
# DEPRECATE THIS
|
||||
if self.hparams.alpha_encoder_loss > 0:
|
||||
loss_encoder = self.calc_mse_loss(enc_outputs, teacher_enc_outputs, src_mask)
|
||||
teacher_enc_hid = self.teacher.get_encoder()(
|
||||
input_ids, attention_mask=src_mask, output_hidden_states=True, return_dict=True
|
||||
).hidden_states
|
||||
|
||||
hid_loss_enc = self.calc_hidden_loss(src_mask, enc_hidden_state, teacher_enc_hid, self.e_layer_ids)
|
||||
|
||||
teacher_enc_outputs = (enc_outputs,)
|
||||
assert isinstance(teacher_enc_outputs, tuple), type(teacher_enc_outputs)
|
||||
hid_loss_enc = self.calc_hidden_loss(
|
||||
src_mask,
|
||||
enc_hidden_state,
|
||||
teacher_enc_hid,
|
||||
self.e_matches,
|
||||
normalize_hidden=self.hparams.normalize_hidden,
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
tloss, tlogits, tdec_hidden, _ = self.teacher(
|
||||
outputs = self.teacher(
|
||||
input_ids,
|
||||
attention_mask=src_mask,
|
||||
encoder_outputs=teacher_enc_outputs,
|
||||
encoder_outputs=(enc_outputs,),
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
lm_labels=labels,
|
||||
output_hidden_states=True,
|
||||
return_dict=True,
|
||||
)
|
||||
tlogits, tdec_hidden = outputs.logits, outputs.decoder_hidden_states
|
||||
dec_mask = decoder_input_ids.ne(pad_token_id)
|
||||
loss_ce = self.calc_ce_loss(dec_mask, lm_logits, tlogits)
|
||||
if self.alpha_hid > 0: # Intermediate supervision of decoder hidden states
|
||||
|
@ -180,10 +192,9 @@ class BartSummarizationDistiller(SummarizationModule):
|
|||
blended_loss = (
|
||||
self.alpha_ce * loss_ce
|
||||
+ self.alpha_mlm * student_lm_loss
|
||||
+ self.hparams.alpha_encoder_loss * loss_encoder
|
||||
+ self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec)
|
||||
)
|
||||
return blended_loss, loss_ce, student_lm_loss, loss_encoder, hid_loss_enc, hid_loss_dec
|
||||
return blended_loss, loss_ce, student_lm_loss, hid_loss_enc, hid_loss_dec
|
||||
|
||||
@staticmethod
|
||||
def calc_hidden_loss(attention_mask, hidden_states, hidden_states_T, matches, normalize_hidden):
|
||||
|
@ -207,7 +218,6 @@ def add_distill_args(parser):
|
|||
parser.add_argument("--teacher", type=str)
|
||||
parser.add_argument("--alpha_ce", default=0.8, type=float)
|
||||
parser.add_argument("--alpha_mlm", default=0.2, type=float)
|
||||
parser.add_argument("--alpha_encoder_loss", default=0.0, type=float)
|
||||
parser.add_argument("--alpha_hid", default=0.0, type=float, required=False)
|
||||
parser.add_argument("--student_decoder_layers", default=12, type=int, required=False)
|
||||
parser.add_argument("--student_encoder_layers", default=12, type=int, required=False)
|
||||
|
|
|
@ -86,7 +86,6 @@ CHEAP_ARGS = {
|
|||
"n_val": -1,
|
||||
"n_test": -1,
|
||||
"student_encoder_layers": 1,
|
||||
"alpha_encoder_loss": 0.0,
|
||||
"freeze_encoder": False,
|
||||
"auto_scale_batch_size": False,
|
||||
}
|
||||
|
@ -230,7 +229,6 @@ class TestSummarizationDistiller(unittest.TestCase):
|
|||
|
||||
evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))
|
||||
|
||||
@unittest.skip("T5 distillation is broken at the moment")
|
||||
def test_distill_t5(self):
|
||||
updates = dict(
|
||||
student_encoder_layers=1,
|
||||
|
@ -255,7 +253,6 @@ class TestSummarizationDistiller(unittest.TestCase):
|
|||
model_name_or_path="sshleifer/tinier_bart",
|
||||
teacher=CHEAP_ARGS["model_name_or_path"],
|
||||
val_check_interval=0.5,
|
||||
alpha_encoder_loss=0.4,
|
||||
)
|
||||
default_updates.update(updates)
|
||||
args_d: dict = CHEAP_ARGS.copy()
|
||||
|
|
Загрузка…
Ссылка в новой задаче