fix bugs for finetuning unispeech

This commit is contained in:
cywang 2022-04-07 04:21:27 -07:00
Родитель 295c961c0c
Коммит e3043e2021
2 изменённых файлов: 8 добавлений и 2 удалений

Просмотреть файл

@ -63,6 +63,9 @@ class Unispeech(BaseFairseqModel):
x = self.w2v_encoder(**kwargs)
return x
def remove_pretraining_modules(self):
self.w2v_encoder.proj = None
class Wav2VecEncoder(FairseqEncoder):
def __init__(self, cfg, task):
super().__init__(task.source_dictionary)

Просмотреть файл

@ -353,16 +353,19 @@ class Wav2VecEncoder(FairseqEncoder):
task = tasks.setup_task(w2v_args.task)
model = task.build_model(w2v_args.model)
model.remove_pretraining_modules()
if state is not None and not cfg.no_pretrained_weights:
model.load_state_dict(state["model"], strict=True)
model.remove_pretraining_modules()
super().__init__(task.source_dictionary)
d = w2v_args.model.encoder_embed_dim
self.w2v_model = model
if hasattr(model, 'w2v_encoder'):
self.w2v_model = model.w2v_encoder.w2v_model
else:
self.w2v_model = model
self.final_dropout = nn.Dropout(cfg.final_dropout)
self.freeze_finetune_updates = cfg.freeze_finetune_updates