support appending tokens to the vocab

This commit is contained in:
Yu Yan 2020-06-10 15:34:55 -07:00 committed by GitHub
Parent a6b24b0cbf
Commit 60871a24d3
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
1 changed file with 25 additions and 0 deletions


@@ -154,6 +154,31 @@ class NgramTransformerProphetModel(FairseqEncoderDecoderModel):
decoder_token_weight[2] = decoder_token_weight[102]
states['encoder.embed_tokens.weight'] = encoder_token_weight
states['decoder.embed_tokens.weight'] = decoder_token_weight
# If the current dictionary is larger than the checkpoint's vocabulary,
# append freshly initialized embedding rows for the extra tokens.
loaded_dict_size = states['encoder.embed_tokens.weight'].size(0)
num_langids_to_add = len(encoder.dictionary) - loaded_dict_size
embed_dim = states['encoder.embed_tokens.weight'].size(1)
if num_langids_to_add > 0:
    new_lang_embed_to_add = torch.zeros(num_langids_to_add, embed_dim)
    # Match fairseq's default embedding init: N(0, embed_dim ** -0.5).
    nn.init.normal_(
        new_lang_embed_to_add,
        mean=0,
        std=embed_dim ** -0.5,
    )
    # Cast to the checkpoint's dtype (e.g. fp16) before concatenating.
    new_lang_embed_to_add = new_lang_embed_to_add.to(
        dtype=states['encoder.embed_tokens.weight'].dtype,
    )
    states['encoder.embed_tokens.weight'] = torch.cat([
        states['encoder.embed_tokens.weight'],
        new_lang_embed_to_add,
    ])
    states['decoder.embed_tokens.weight'] = torch.cat([
        states['decoder.embed_tokens.weight'],
        new_lang_embed_to_add,
    ])
# Position embeddings may also need resizing when the loaded checkpoint
# covers fewer positions than the current model.
for position_name, target_position_length in [
    ('encoder.embed_positions.weight', model.encoder.embed_positions.weight.size(0)),
    ('decoder.embed_positions.weight', model.decoder.embed_positions.weight.size(0)),
]:
    if states[position_name].size(0) < target_position_length:
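
The embedding-growing logic above can be exercised in isolation. Below is a minimal, self-contained sketch of the same append-rows pattern; `append_embedding_rows` and its arguments are illustrative names introduced here, not part of the commit.

import torch
import torch.nn as nn

def append_embedding_rows(weight: torch.Tensor, new_vocab_size: int) -> torch.Tensor:
    """Return ``weight`` with rows appended so it covers ``new_vocab_size`` tokens."""
    old_vocab_size, embed_dim = weight.size()
    num_to_add = new_vocab_size - old_vocab_size
    if num_to_add <= 0:
        return weight  # checkpoint already covers the vocabulary
    new_rows = torch.zeros(num_to_add, embed_dim)
    nn.init.normal_(new_rows, mean=0, std=embed_dim ** -0.5)  # fairseq-style init
    # Cast to the existing weight's dtype and append along the vocab dimension.
    return torch.cat([weight, new_rows.to(dtype=weight.dtype)])

# Grow a 100-token table to 104 tokens; the first 100 rows are untouched.
weight = torch.randn(100, 512, dtype=torch.float16)
grown = append_embedding_rows(weight, 104)
assert grown.shape == (104, 512) and torch.equal(grown[:100], weight)

Appending rows (rather than reallocating and re-initializing the whole table) preserves all pretrained token embeddings, so only the newly added vocabulary entries start from random weights.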