This commit is contained in:
Sanxing Chen 2021-11-01 03:24:13 +00:00
Родитель eedd02a214
Коммит b508106881
2 изменённых файлов: 4 добавлений и 1 удалений

Просмотреть файл

@ -230,7 +230,7 @@ class CrossTrmFinetuner(pl.LightningModule):
if self._hparams['use_hitter']:
# kg_masks: [bs, 1, 1, length]
# kg_embeds: nlayer*[bs, length, dim]
kg_embeds, kg_masks = self.hitter('get_hitter_repr', s=s, p=p)
kg_embeds, kg_masks = self.hitter('get_hitter_repr', s, p)
kg_attentions = [None] * 2 + [(self.cross_attentions[i], kg_embeds[(i + 2) // 2], kg_masks)
for i in range(self.kg_layer_num)]
else:

Просмотреть файл

@ -229,6 +229,9 @@ class TrmE(KgeModel):
del self._scorer._entity_embedder
del self._scorer._relation_embedder
if fn_name == 'get_hitter_repr':
return scores
if self.training:
self_loss_w = self.get_option("self_dropout")
# MLM-like loss is weighted by the proportion of entities sampled