зеркало из https://github.com/microsoft/UniSpeech.git
fix bugs for finetuning unispeech
This commit is contained in:
Родитель
295c961c0c
Коммит
e3043e2021
|
@ -63,6 +63,9 @@ class Unispeech(BaseFairseqModel):
|
|||
x = self.w2v_encoder(**kwargs)
|
||||
return x
|
||||
|
||||
def remove_pretraining_modules(self):
|
||||
self.w2v_encoder.proj = None
|
||||
|
||||
class Wav2VecEncoder(FairseqEncoder):
|
||||
def __init__(self, cfg, task):
|
||||
super().__init__(task.source_dictionary)
|
||||
|
|
|
@ -353,16 +353,19 @@ class Wav2VecEncoder(FairseqEncoder):
|
|||
task = tasks.setup_task(w2v_args.task)
|
||||
model = task.build_model(w2v_args.model)
|
||||
|
||||
model.remove_pretraining_modules()
|
||||
if state is not None and not cfg.no_pretrained_weights:
|
||||
model.load_state_dict(state["model"], strict=True)
|
||||
|
||||
model.remove_pretraining_modules()
|
||||
|
||||
super().__init__(task.source_dictionary)
|
||||
|
||||
d = w2v_args.model.encoder_embed_dim
|
||||
|
||||
self.w2v_model = model
|
||||
if hasattr(model, 'w2v_encoder'):
|
||||
self.w2v_model = model.w2v_encoder.w2v_model
|
||||
else:
|
||||
self.w2v_model = model
|
||||
|
||||
self.final_dropout = nn.Dropout(cfg.final_dropout)
|
||||
self.freeze_finetune_updates = cfg.freeze_finetune_updates
|
||||
|
|
Загрузка…
Ссылка в новой задаче