diff --git a/README.md b/README.md
index 178463e..264b729 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,8 @@
+*WORK IN PROGRESS ...*
+
 The implementation of paper [**UniVL: A Unified Video and Language Pre-Training Model for Multimodal Understanding and Generation**](https://arxiv.org/abs/2002.06353).
 
-UniVL is a **video-language pretrain model**. It is designed with four modules and five objectives for both video language understanding and generation tasks.It is also a flexible model for most of the multimodal downstream tasks considering both efficiency and effectiveness.
+UniVL is a **video-language pretrain model**. It is designed with four modules and five objectives for both video language understanding and generation tasks. It is also a flexible model for most of the multimodal downstream tasks considering both efficiency and effectiveness.
 
 # Preliminary
 Excute below scripts in the main folder firstly.
diff --git a/modules/modeling.py b/modules/modeling.py
index d7902f1..ab169db 100644
--- a/modules/modeling.py
+++ b/modules/modeling.py
@@ -159,11 +159,12 @@ class UniVL(UniVLPreTrainedModel):
         self.decoder = DecoderModel(decoder_config, bert_word_embeddings_weight, bert_position_embeddings_weight)
         # <=== End of Decoder
 
-        self.cls = BertOnlyMLMHead(bert_config, bert_word_embeddings_weight)
-        self.cls_visual = VisualOnlyMLMHead(visual_config, visual_word_embeddings_weight)
-
+        if self.task_config.do_pretrain:
+            self.cls = BertOnlyMLMHead(bert_config, bert_word_embeddings_weight)
+            self.cls_visual = VisualOnlyMLMHead(visual_config, visual_word_embeddings_weight)
+            self.alm_loss_fct = CrossEntropyLoss(ignore_index=-1)
+            self.similarity_dense = nn.Linear(bert_config.hidden_size, 1)
 
-        self.alm_loss_fct = CrossEntropyLoss(ignore_index=-1)
         self.decoder_loss_fct = CrossEntropyLoss(ignore_index=-1)
 
         self.normalize_video = NormalizeVideo(task_config)
@@ -209,8 +210,7 @@ class UniVL(UniVLPreTrainedModel):
                 sim_loss = self.loss_fct(sim_matrix)
                 loss += sim_loss
 
-            if self._stage_two and pairs_masked_text is not None and pairs_token_labels is not None:
-
+            if self._stage_two:
                 if self.task_config.do_pretrain:
                     pairs_masked_text = pairs_masked_text.view(-1, pairs_masked_text.shape[-1])
                     pairs_token_labels = pairs_token_labels.view(-1, pairs_token_labels.shape[-1])
@@ -379,18 +379,14 @@ class UniVL(UniVLPreTrainedModel):
             attention_mask = attention_mask.view(-1, attention_mask.shape[-1])
             video_mask = video_mask.view(-1, video_mask.shape[-1])
 
-        if self._stage_two and _pretrain_joint is False:
+        if (self._stage_two and _pretrain_joint is False) or self.train_sim_after_cross:
             retrieve_logits = self._cross_similarity(sequence_output, visual_output, attention_mask, video_mask)
         else:
-            if self.train_sim_after_cross:
-                retrieve_logits = self._cross_similarity(sequence_output, visual_output, attention_mask, video_mask)
-            else:
-                text_out, video_out = self._mean_pooling_for_similarity(sequence_output, visual_output, attention_mask, video_mask)
-                # Do a cosine simlarity
-                if self.task_config.use_mil is False:
-                    text_out = F.normalize(text_out, dim=-1)
-                    video_out = F.normalize(video_out, dim=-1)
-                retrieve_logits = torch.matmul(text_out, video_out.t())
+            text_out, video_out = self._mean_pooling_for_similarity(sequence_output, visual_output, attention_mask, video_mask)
+            if self.task_config.use_mil is False:
+                text_out = F.normalize(text_out, dim=-1)
+                video_out = F.normalize(video_out, dim=-1)
+            retrieve_logits = torch.matmul(text_out, video_out.t())
 
         return retrieve_logits
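
Below is a minimal standalone sketch (not part of the patch) of the consolidated similarity branch from the last modeling.py hunk. The free function and its explicit flag/callable arguments are hypothetical stand-ins for the module attributes (self._stage_two, self.train_sim_after_cross, self.task_config.use_mil) and the methods self._cross_similarity / self._mean_pooling_for_similarity:

    import torch
    import torch.nn.functional as F

    def similarity_logits(sequence_output, visual_output, stage_two, pretrain_joint,
                          train_sim_after_cross, use_mil, cross_similarity, mean_pooling):
        # The patch folds the old nested if/else into one condition: take the
        # cross-encoder path in stage two (outside joint pretraining), or
        # whenever the similarity head is trained after the cross encoder.
        if (stage_two and pretrain_joint is False) or train_sim_after_cross:
            return cross_similarity(sequence_output, visual_output)
        # Fallback: mean-pool each modality, then a dot product over all
        # text/video pairs; L2-normalize first (cosine similarity) unless MIL is used.
        text_out, video_out = mean_pooling(sequence_output, visual_output)
        if use_mil is False:
            text_out = F.normalize(text_out, dim=-1)
            video_out = F.normalize(video_out, dim=-1)
        return torch.matmul(text_out, video_out.t())

Given pooled text features of shape (N, d) and video features of shape (M, d), the final matmul yields an N x M matrix of retrieval logits, one score per text-video pair.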