Support appending tokens to the vocab
Parent: a6b24b0cbf
Commit: 60871a24d3
@@ -154,6 +154,31 @@ class NgramTransformerProphetModel(FairseqEncoderDecoderModel):
        decoder_token_weight[2] = decoder_token_weight[102]
        states['encoder.embed_tokens.weight'] = encoder_token_weight
        states['decoder.embed_tokens.weight'] = decoder_token_weight

        loaded_dict_size = states['encoder.embed_tokens.weight'].size(0)
        num_langids_to_add = len(encoder.dictionary) - loaded_dict_size
        embed_dim = states['encoder.embed_tokens.weight'].size(1)

        if num_langids_to_add > 0:
            new_lang_embed_to_add = torch.zeros(num_langids_to_add, embed_dim)
            nn.init.normal_(
                new_lang_embed_to_add,
                mean=0,
                std=embed_dim ** -0.5
            )
            new_lang_embed_to_add = new_lang_embed_to_add.to(
                dtype=states['encoder.embed_tokens.weight'].dtype,
            )

            states['encoder.embed_tokens.weight'] = torch.cat([
                states['encoder.embed_tokens.weight'],
                new_lang_embed_to_add]
            )
            states['decoder.embed_tokens.weight'] = torch.cat([
                states['decoder.embed_tokens.weight'],
                new_lang_embed_to_add]
            )

        for position_name, target_position_length in [('encoder.embed_positions.weight', model.encoder.embed_positions.weight.size(0)),
                                                       ('decoder.embed_positions.weight', model.decoder.embed_positions.weight.size(0))]:
            if states[position_name].size(0) < target_position_length:
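The added hunk grows the pretrained token-embedding matrices so that a checkpoint built with a smaller dictionary can still be loaded after extra tokens (for example language-ID symbols) have been appended to the vocabulary. Below is a minimal, self-contained sketch of that step for a single weight; the helper name extend_token_embeddings and its arguments are hypothetical and only for illustration, not part of the commit.

    import torch
    import torch.nn as nn

    def extend_token_embeddings(states, key, new_vocab_size):
        # Hypothetical helper illustrating the same idea as the hunk above:
        # append freshly initialized rows to a checkpoint's token-embedding
        # matrix so it matches an enlarged dictionary.
        weight = states[key]                      # shape: (old_vocab, embed_dim)
        old_vocab_size, embed_dim = weight.size()
        num_to_add = new_vocab_size - old_vocab_size
        if num_to_add <= 0:
            return states                         # nothing to append

        new_rows = torch.zeros(num_to_add, embed_dim)
        # Same initialization the hunk uses: N(0, embed_dim ** -0.5).
        nn.init.normal_(new_rows, mean=0, std=embed_dim ** -0.5)

        states[key] = torch.cat([weight, new_rows.to(dtype=weight.dtype)])
        return states

Note that the commit builds new_lang_embed_to_add once and concatenates the same rows onto both the encoder and the decoder token embeddings, so every newly added token starts from identical initial values in both matrices.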
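The hunk is cut off right after the check that compares the checkpoint's position-embedding length with the length expected by the freshly built model, so the body of that branch is not visible here. Purely as an illustration of one generic way to handle such a mismatch (which may differ from what the commit actually does in the omitted lines), the table could be padded with newly initialized rows:

    import torch
    import torch.nn as nn

    def extend_position_embeddings(states, key, target_length):
        # Hypothetical sketch only; the commit's real handling of this branch
        # is not shown in the truncated hunk above.
        weight = states[key]                      # shape: (old_positions, embed_dim)
        old_length, embed_dim = weight.size()
        if old_length >= target_length:
            return states

        extra = torch.zeros(target_length - old_length, embed_dim)
        nn.init.normal_(extra, mean=0, std=embed_dim ** -0.5)
        states[key] = torch.cat([weight, extra.to(dtype=weight.dtype)])
        return states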