This commit is contained in:
ArrowLuo 2021-01-29 12:30:27 +08:00
Родитель 962eb07554
Коммит 3194243283
2 изменённых файлов: 15 добавлений и 17 удалений

Просмотреть файл

@ -1,6 +1,8 @@
*WORK IN PROGRESS ...*
The implementation of paper [**UniVL: A Unified Video and Language Pre-Training Model for Multimodal Understanding and Generation**](https://arxiv.org/abs/2002.06353).
UniVL is a **video-language pretrain model**. It is designed with four modules and five objectives for both video language understanding and generation tasks.It is also a flexible model for most of the multimodal downstream tasks considering both efficiency and effectiveness.
UniVL is a **video-language pretrain model**. It is designed with four modules and five objectives for both video language understanding and generation tasks. It is also a flexible model for most of the multimodal downstream tasks considering both efficiency and effectiveness.
# Preliminary
Excute below scripts in the main folder firstly.

Просмотреть файл

@ -159,11 +159,12 @@ class UniVL(UniVLPreTrainedModel):
self.decoder = DecoderModel(decoder_config, bert_word_embeddings_weight, bert_position_embeddings_weight)
# <=== End of Decoder
if self.task_config.do_pretrain:
self.cls = BertOnlyMLMHead(bert_config, bert_word_embeddings_weight)
self.cls_visual = VisualOnlyMLMHead(visual_config, visual_word_embeddings_weight)
self.alm_loss_fct = CrossEntropyLoss(ignore_index=-1)
self.similarity_dense = nn.Linear(bert_config.hidden_size, 1)
self.alm_loss_fct = CrossEntropyLoss(ignore_index=-1)
self.decoder_loss_fct = CrossEntropyLoss(ignore_index=-1)
self.normalize_video = NormalizeVideo(task_config)
@ -209,8 +210,7 @@ class UniVL(UniVLPreTrainedModel):
sim_loss = self.loss_fct(sim_matrix)
loss += sim_loss
if self._stage_two and pairs_masked_text is not None and pairs_token_labels is not None:
if self._stage_two:
if self.task_config.do_pretrain:
pairs_masked_text = pairs_masked_text.view(-1, pairs_masked_text.shape[-1])
pairs_token_labels = pairs_token_labels.view(-1, pairs_token_labels.shape[-1])
@ -379,14 +379,10 @@ class UniVL(UniVLPreTrainedModel):
attention_mask = attention_mask.view(-1, attention_mask.shape[-1])
video_mask = video_mask.view(-1, video_mask.shape[-1])
if self._stage_two and _pretrain_joint is False:
retrieve_logits = self._cross_similarity(sequence_output, visual_output, attention_mask, video_mask)
else:
if self.train_sim_after_cross:
if (self._stage_two and _pretrain_joint is False) or self.train_sim_after_cross:
retrieve_logits = self._cross_similarity(sequence_output, visual_output, attention_mask, video_mask)
else:
text_out, video_out = self._mean_pooling_for_similarity(sequence_output, visual_output, attention_mask, video_mask)
# Do a cosine simlarity
if self.task_config.use_mil is False:
text_out = F.normalize(text_out, dim=-1)
video_out = F.normalize(video_out, dim=-1)