diff --git a/MASS-supNMT/mass/xmasked_seq2seq.py b/MASS-supNMT/mass/xmasked_seq2seq.py index ea030a0..6cfbad2 100644 --- a/MASS-supNMT/mass/xmasked_seq2seq.py +++ b/MASS-supNMT/mass/xmasked_seq2seq.py @@ -427,6 +427,9 @@ class XMassTranslationTask(FairseqTask): return agg_loss, agg_sample_size, agg_logging_output def inference_step(self, generator, models, sample, prefix_tokens=None): + for model in models: + model.source_lang = self.args.source_lang + model.target_lang = self.args.target_lang with torch.no_grad(): return generator.generate( models,