# coding=utf-8
# Copyright 2020, The RAG Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import shutil
import tempfile
import unittest
from unittest . mock import patch
import numpy as np
2020-10-18 21:51:24 +03:00
from transformers import BartTokenizer , T5Tokenizer
2020-09-22 19:29:58 +03:00
from transformers . file_utils import cached_property , is_datasets_available , is_faiss_available , is_torch_available
2020-11-17 05:43:42 +03:00
from transformers . models . bert . tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
from transformers . models . dpr . tokenization_dpr import DPRQuestionEncoderTokenizer
from transformers . models . roberta . tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
2020-10-30 17:25:48 +03:00
from transformers . testing_utils import (
require_sentencepiece ,
require_tokenizers ,
require_torch ,
2020-11-10 21:23:58 +03:00
require_torch_non_multi_gpu ,
2020-10-30 17:25:48 +03:00
slow ,
torch_device ,
)
2020-09-22 19:29:58 +03:00
2020-12-09 22:55:24 +03:00
from . test_modeling_bart import BartModelTester
2020-09-22 19:29:58 +03:00
from . test_modeling_dpr import DPRModelTester
from . test_modeling_t5 import T5ModelTester
TOLERANCE = 1e-3
T5_SAMPLE_VOCAB = os . path . join ( os . path . dirname ( os . path . abspath ( __file__ ) ) , " fixtures/test_sentencepiece.model " )
if is_torch_available ( ) and is_datasets_available ( ) and is_faiss_available ( ) :
import torch
from datasets import Dataset
import faiss
from transformers import (
AutoConfig ,
AutoModel ,
AutoModelForSeq2SeqLM ,
RagConfig ,
RagModel ,
RagRetriever ,
RagSequenceForGeneration ,
RagTokenForGeneration ,
2020-09-25 17:12:46 +03:00
RagTokenizer ,
2020-09-22 19:29:58 +03:00
)
from transformers . modeling_outputs import BaseModelOutput
def _assert_tensors_equal ( a , b , atol = 1e-12 , prefix = " " ) :
""" If tensors not close, or a and b arent both tensors, raise a nice Assertion error. """
if a is None and b is None :
return True
try :
if torch . allclose ( a , b , atol = atol ) :
return True
raise
except Exception :
2021-03-31 17:00:27 +03:00
msg = f " { a } != { b } "
2020-09-22 19:29:58 +03:00
if prefix :
msg = prefix + " : " + msg
raise AssertionError ( msg )
def require_retrieval ( test_case ) :
"""
Decorator marking a test that requires a set of dependencies necessary for pefrorm retrieval with
: class : ` ~ transformers . RagRetriever ` .
These tests are skipped when respective libraries are not installed .
"""
if not ( is_torch_available ( ) and is_datasets_available ( ) and is_faiss_available ( ) ) :
2020-10-19 16:15:52 +03:00
test_case = unittest . skip ( " test requires PyTorch, datasets and faiss " ) ( test_case )
2020-09-22 19:29:58 +03:00
return test_case
@require_torch
@require_retrieval
2020-10-18 21:51:24 +03:00
@require_sentencepiece
2020-09-22 19:29:58 +03:00
class RagTestMixin :
all_model_classes = (
( RagModel , RagTokenForGeneration , RagSequenceForGeneration )
if is_torch_available ( ) and is_datasets_available ( ) and is_faiss_available ( )
else ( )
)
retrieval_vector_size = 32
2020-10-19 16:15:52 +03:00
n_docs = 3
2020-09-22 19:29:58 +03:00
max_combined_length = 16
def setUp ( self ) :
self . tmpdirname = tempfile . mkdtemp ( )
# DPR tok
vocab_tokens = [
" [UNK] " ,
" [CLS] " ,
" [SEP] " ,
" [PAD] " ,
" [MASK] " ,
" want " ,
" ##want " ,
" ##ed " ,
" wa " ,
" un " ,
" runn " ,
" ##ing " ,
" , " ,
" low " ,
" lowest " ,
]
dpr_tokenizer_path = os . path . join ( self . tmpdirname , " dpr_tokenizer " )
os . makedirs ( dpr_tokenizer_path , exist_ok = True )
self . vocab_file = os . path . join ( dpr_tokenizer_path , DPR_VOCAB_FILES_NAMES [ " vocab_file " ] )
with open ( self . vocab_file , " w " , encoding = " utf-8 " ) as vocab_writer :
vocab_writer . write ( " " . join ( [ x + " \n " for x in vocab_tokens ] ) )
# BART tok
vocab = [
" l " ,
" o " ,
" w " ,
" e " ,
" r " ,
" s " ,
" t " ,
" i " ,
" d " ,
" n " ,
" \u0120 " ,
" \u0120 l " ,
" \u0120 n " ,
" \u0120 lo " ,
" \u0120 low " ,
" er " ,
" \u0120 lowest " ,
" \u0120 newer " ,
" \u0120 wider " ,
" <unk> " ,
]
vocab_tokens = dict ( zip ( vocab , range ( len ( vocab ) ) ) )
merges = [ " #version: 0.2 " , " \u0120 l " , " \u0120 l o " , " \u0120 lo w " , " e r " , " " ]
self . special_tokens_map = { " unk_token " : " <unk> " }
bart_tokenizer_path = os . path . join ( self . tmpdirname , " bart_tokenizer " )
os . makedirs ( bart_tokenizer_path , exist_ok = True )
self . vocab_file = os . path . join ( bart_tokenizer_path , BART_VOCAB_FILES_NAMES [ " vocab_file " ] )
self . merges_file = os . path . join ( bart_tokenizer_path , BART_VOCAB_FILES_NAMES [ " merges_file " ] )
with open ( self . vocab_file , " w " , encoding = " utf-8 " ) as fp :
fp . write ( json . dumps ( vocab_tokens ) + " \n " )
with open ( self . merges_file , " w " , encoding = " utf-8 " ) as fp :
fp . write ( " \n " . join ( merges ) )
t5_tokenizer = T5Tokenizer ( T5_SAMPLE_VOCAB )
t5_tokenizer_path = os . path . join ( self . tmpdirname , " t5_tokenizer " )
t5_tokenizer . save_pretrained ( t5_tokenizer_path )
@cached_property
def dpr_tokenizer ( self ) - > DPRQuestionEncoderTokenizer :
return DPRQuestionEncoderTokenizer . from_pretrained ( os . path . join ( self . tmpdirname , " dpr_tokenizer " ) )
@cached_property
def bart_tokenizer ( self ) - > BartTokenizer :
return BartTokenizer . from_pretrained ( os . path . join ( self . tmpdirname , " bart_tokenizer " ) )
@cached_property
def t5_tokenizer ( self ) - > BartTokenizer :
return T5Tokenizer . from_pretrained ( os . path . join ( self . tmpdirname , " t5_tokenizer " ) )
def tearDown ( self ) :
shutil . rmtree ( self . tmpdirname )
def get_retriever ( self , config ) :
dataset = Dataset . from_dict (
{
2020-10-19 16:15:52 +03:00
" id " : [ " 0 " , " 1 " , " 3 " ] ,
" text " : [ " foo " , " bar " , " qux " ] ,
" title " : [ " Foo " , " Bar " , " Qux " ] ,
" embeddings " : [
np . ones ( self . retrieval_vector_size ) ,
2 * np . ones ( self . retrieval_vector_size ) ,
3 * np . ones ( self . retrieval_vector_size ) ,
] ,
2020-09-22 19:29:58 +03:00
}
)
dataset . add_faiss_index ( " embeddings " , string_factory = " Flat " , metric_type = faiss . METRIC_INNER_PRODUCT )
tokenizer = self . bart_tokenizer if config . generator . model_type == " bart " else self . t5_tokenizer
2020-11-17 05:43:42 +03:00
with patch ( " transformers.models.rag.retrieval_rag.load_dataset " ) as mock_load_dataset :
2020-09-22 19:29:58 +03:00
mock_load_dataset . return_value = dataset
retriever = RagRetriever (
config ,
question_encoder_tokenizer = self . dpr_tokenizer ,
generator_tokenizer = tokenizer ,
)
return retriever
def check_model_with_retriever (
self , config , input_ids , attention_mask , decoder_input_ids , decoder_attention_mask , * * kwargs
) :
self . assertIsNotNone ( config . question_encoder )
self . assertIsNotNone ( config . generator )
for model_class in self . all_model_classes :
model = model_class ( config , retriever = self . get_retriever ( config ) ) . to ( torch_device )
model . eval ( )
self . assertTrue ( model . config . is_encoder_decoder )
outputs = model (
input_ids = input_ids ,
attention_mask = attention_mask ,
decoder_input_ids = decoder_input_ids ,
decoder_attention_mask = decoder_attention_mask ,
)
# logits
self . assertEqual (
outputs . logits . shape ,
( self . n_docs * decoder_input_ids . shape [ 0 ] , decoder_input_ids . shape [ 1 ] , config . generator . vocab_size ) ,
)
# generator encoder last hidden states
self . assertEqual (
outputs . generator_enc_last_hidden_state . shape ,
( self . n_docs * decoder_input_ids . shape [ 0 ] , self . max_combined_length , config . generator . hidden_size ) ,
)
# doc scores
self . assertEqual ( outputs . doc_scores . shape , ( input_ids . shape [ 0 ] , self . n_docs ) )
Proposed Fix : [RagSequenceForGeneration] generate "without" input_ids (#9220)
* Create modeling_tf_dpr.py
* Add TFDPR
* Add back TFPegasus, TFMarian, TFMBart, TFBlenderBot
last commit accidentally deleted these 4 lines, so I recover them back
* Add TFDPR
* Add TFDPR
* clean up some comments, add TF input-style doc string
* Add TFDPR
* Make return_dict=False as default
* Fix return_dict bug (in .from_pretrained)
* Add get_input_embeddings()
* Create test_modeling_tf_dpr.py
The current version is already passed all 27 tests!
Please see the test run at :
https://colab.research.google.com/drive/1czS_m9zy5k-iSJbzA_DP1k1xAAC_sdkf?usp=sharing
* fix quality
* delete init weights
* run fix copies
* fix repo consis
* del config_class, load_tf_weights
They shoud be 'pytorch only'
* add config_class back
after removing it, test failed ... so totally only removing "use_tf_weights = None" on Lysandre suggestion
* newline after .. note::
* import tf, np (Necessary for ModelIntegrationTest)
* slow_test from_pretrained with from_pt=True
At the moment we don't have TF weights (since we don't have official official TF model)
Previously, I did not run slow test, so I missed this bug
* Add simple TFDPRModelIntegrationTest
Note that this is just a test that TF and Pytorch gives approx. the same output.
However, I could not test with the official DPR repo's output yet
* upload correct tf model
* remove position_ids as missing keys
* fix RagSeq generate with context_input_ids
fix RagSeq generate with context_input_ids
* apply style
* delete unused lines
* Add test_rag_sequence_generate_batch_from_context_input_ids
* Readability improved
* stylying
* Stylize
* typos
* add check_model_generate_from_context_input_ids
* make style
* Apply suggestions from code review
* make style2
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: patrickvonplaten <patrick@huggingface.co>
2020-12-24 15:38:00 +03:00
def check_model_generate_from_context_input_ids (
self , config , input_ids , attention_mask , decoder_input_ids , decoder_attention_mask , * * kwargs
) :
self . assertIsNotNone ( config . question_encoder )
self . assertIsNotNone ( config . generator )
retriever = self . get_retriever ( config )
for model_class in self . all_model_classes :
model = model_class ( config ) . to ( torch_device )
model . eval ( )
self . assertTrue ( model . config . is_encoder_decoder )
question_hidden_states = model . question_encoder ( input_ids , attention_mask = attention_mask ) [ 0 ]
out = retriever (
input_ids ,
question_hidden_states . cpu ( ) . detach ( ) . to ( torch . float32 ) . numpy ( ) ,
prefix = config . generator . prefix ,
return_tensors = " pt " ,
)
context_input_ids , context_attention_mask , retrieved_doc_embeds = (
out [ " context_input_ids " ] ,
out [ " context_attention_mask " ] ,
out [ " retrieved_doc_embeds " ] ,
)
# cast
retrieved_doc_embeds = retrieved_doc_embeds . to ( question_hidden_states )
context_input_ids = context_input_ids . to ( input_ids )
context_attention_mask = context_attention_mask . to ( input_ids )
# compute doc_scores
doc_scores = torch . bmm ( question_hidden_states . unsqueeze ( 1 ) , retrieved_doc_embeds . transpose ( 1 , 2 ) ) . squeeze (
1
)
outputs = model . generate (
context_input_ids = context_input_ids ,
context_attention_mask = context_attention_mask ,
doc_scores = doc_scores ,
do_deduplication = True ,
)
self . assertIsNotNone ( outputs )
2020-09-22 19:29:58 +03:00
def check_model_generate (
self , config , input_ids , attention_mask , decoder_input_ids , decoder_attention_mask , * * kwargs
) :
self . assertIsNotNone ( config . question_encoder )
self . assertIsNotNone ( config . generator )
for model_class in self . all_model_classes [ 1 : ] :
model = model_class ( config , retriever = self . get_retriever ( config ) ) . to ( torch_device )
model . eval ( )
self . assertTrue ( model . config . is_encoder_decoder )
outputs = model . generate (
input_ids = input_ids ,
num_beams = 2 ,
num_return_sequences = 2 ,
decoder_start_token_id = config . generator . eos_token_id ,
)
self . assertIsNotNone ( outputs )
def check_model_without_retriever (
self , config , input_ids , attention_mask , decoder_input_ids , decoder_attention_mask , * * kwargs
) :
self . assertIsNotNone ( config . question_encoder )
self . assertIsNotNone ( config . generator )
retriever = self . get_retriever ( config )
for model_class in self . all_model_classes :
model = model_class ( config ) . to ( torch_device )
model . eval ( )
self . assertTrue ( model . config . is_encoder_decoder )
question_hidden_states = model . question_encoder ( input_ids , attention_mask = attention_mask ) [ 0 ]
out = retriever (
input_ids ,
question_hidden_states . cpu ( ) . detach ( ) . to ( torch . float32 ) . numpy ( ) ,
prefix = config . generator . prefix ,
return_tensors = " pt " ,
)
context_input_ids , context_attention_mask , retrieved_doc_embeds = (
out [ " context_input_ids " ] ,
out [ " context_attention_mask " ] ,
out [ " retrieved_doc_embeds " ] ,
)
# cast
retrieved_doc_embeds = retrieved_doc_embeds . to ( question_hidden_states )
context_input_ids = context_input_ids . to ( input_ids )
context_attention_mask = context_attention_mask . to ( input_ids )
# compute doc_scores
doc_scores = torch . bmm ( question_hidden_states . unsqueeze ( 1 ) , retrieved_doc_embeds . transpose ( 1 , 2 ) ) . squeeze (
1
)
outputs = model (
context_input_ids = context_input_ids ,
context_attention_mask = context_attention_mask ,
doc_scores = doc_scores ,
decoder_input_ids = decoder_input_ids ,
decoder_attention_mask = decoder_attention_mask ,
)
# logits
self . assertEqual (
outputs . logits . shape ,
( self . n_docs * decoder_input_ids . shape [ 0 ] , decoder_input_ids . shape [ 1 ] , config . generator . vocab_size ) ,
)
# generator encoder last hidden states
self . assertEqual (
outputs . generator_enc_last_hidden_state . shape ,
( self . n_docs * decoder_input_ids . shape [ 0 ] , self . max_combined_length , config . generator . hidden_size ) ,
)
# doc scores
self . assertEqual ( outputs . doc_scores . shape , ( input_ids . shape [ 0 ] , self . n_docs ) )
2020-10-19 16:15:52 +03:00
def check_model_custom_n_docs (
self , config , input_ids , attention_mask , decoder_input_ids , decoder_attention_mask , n_docs , * * kwargs
) :
self . assertIsNotNone ( config . question_encoder )
self . assertIsNotNone ( config . generator )
retriever = self . get_retriever ( config )
for model_class in self . all_model_classes :
model = model_class ( config ) . to ( torch_device )
model . eval ( )
self . assertTrue ( model . config . is_encoder_decoder )
question_hidden_states = model . question_encoder ( input_ids , attention_mask = attention_mask ) [ 0 ]
out = retriever (
input_ids ,
question_hidden_states . cpu ( ) . detach ( ) . to ( torch . float32 ) . numpy ( ) ,
prefix = config . generator . prefix ,
return_tensors = " pt " ,
n_docs = n_docs ,
)
context_input_ids , context_attention_mask , retrieved_doc_embeds = (
out [ " context_input_ids " ] ,
out [ " context_attention_mask " ] ,
out [ " retrieved_doc_embeds " ] ,
)
# cast
retrieved_doc_embeds = retrieved_doc_embeds . to ( question_hidden_states )
context_input_ids = context_input_ids . to ( input_ids )
context_attention_mask = context_attention_mask . to ( input_ids )
# compute doc_scores
doc_scores = torch . bmm ( question_hidden_states . unsqueeze ( 1 ) , retrieved_doc_embeds . transpose ( 1 , 2 ) ) . squeeze (
1
)
outputs = model (
context_input_ids = context_input_ids ,
context_attention_mask = context_attention_mask ,
doc_scores = doc_scores ,
decoder_input_ids = decoder_input_ids ,
decoder_attention_mask = decoder_attention_mask ,
n_docs = n_docs ,
)
# logits
self . assertEqual (
outputs . logits . shape ,
( n_docs * decoder_input_ids . shape [ 0 ] , decoder_input_ids . shape [ 1 ] , config . generator . vocab_size ) ,
)
# generator encoder last hidden states
self . assertEqual (
outputs . generator_enc_last_hidden_state . shape ,
( n_docs * decoder_input_ids . shape [ 0 ] , self . max_combined_length , config . generator . hidden_size ) ,
)
# doc scores
self . assertEqual ( outputs . doc_scores . shape , ( input_ids . shape [ 0 ] , n_docs ) )
def check_model_with_mismatch_n_docs_value (
self ,
config ,
input_ids ,
attention_mask ,
decoder_input_ids ,
decoder_attention_mask ,
retriever_n_docs ,
generator_n_docs ,
* * kwargs
) :
self . assertIsNotNone ( config . question_encoder )
self . assertIsNotNone ( config . generator )
retriever = self . get_retriever ( config )
for model_class in self . all_model_classes :
model = model_class ( config ) . to ( torch_device )
model . eval ( )
self . assertTrue ( model . config . is_encoder_decoder )
question_hidden_states = model . question_encoder ( input_ids , attention_mask = attention_mask ) [ 0 ]
out = retriever (
input_ids ,
question_hidden_states . cpu ( ) . detach ( ) . to ( torch . float32 ) . numpy ( ) ,
prefix = config . generator . prefix ,
return_tensors = " pt " ,
n_docs = retriever_n_docs ,
)
context_input_ids , context_attention_mask , retrieved_doc_embeds = (
out [ " context_input_ids " ] ,
out [ " context_attention_mask " ] ,
out [ " retrieved_doc_embeds " ] ,
)
# cast
retrieved_doc_embeds = retrieved_doc_embeds . to ( question_hidden_states )
context_input_ids = context_input_ids . to ( input_ids )
context_attention_mask = context_attention_mask . to ( input_ids )
# compute doc_scores
doc_scores = torch . bmm ( question_hidden_states . unsqueeze ( 1 ) , retrieved_doc_embeds . transpose ( 1 , 2 ) ) . squeeze (
1
)
self . assertRaises (
AssertionError ,
model . __call__ ,
context_input_ids = context_input_ids ,
context_attention_mask = context_attention_mask ,
doc_scores = doc_scores ,
decoder_input_ids = decoder_input_ids ,
decoder_attention_mask = decoder_attention_mask ,
n_docs = generator_n_docs ,
)
2020-09-22 19:29:58 +03:00
def check_model_with_encoder_outputs (
self , config , input_ids , attention_mask , decoder_input_ids , decoder_attention_mask , * * kwargs
) :
self . assertIsNotNone ( config . question_encoder )
self . assertIsNotNone ( config . generator )
for model_class in self . all_model_classes :
model = model_class ( config , retriever = self . get_retriever ( config ) ) . to ( torch_device )
model . eval ( )
self . assertTrue ( model . config . is_encoder_decoder )
outputs = model (
input_ids = input_ids ,
attention_mask = attention_mask ,
decoder_input_ids = decoder_input_ids ,
decoder_attention_mask = decoder_attention_mask ,
)
encoder_outputs = BaseModelOutput ( outputs . generator_enc_last_hidden_state )
# run only generator
outputs = model (
encoder_outputs = encoder_outputs ,
doc_scores = outputs . doc_scores ,
decoder_input_ids = decoder_input_ids ,
decoder_attention_mask = decoder_attention_mask ,
)
# logits
self . assertEqual (
outputs . logits . shape ,
( self . n_docs * decoder_input_ids . shape [ 0 ] , decoder_input_ids . shape [ 1 ] , config . generator . vocab_size ) ,
)
# generator encoder last hidden states
self . assertEqual (
outputs . generator_enc_last_hidden_state . shape ,
( self . n_docs * decoder_input_ids . shape [ 0 ] , self . max_combined_length , config . generator . hidden_size ) ,
)
# doc scores
self . assertEqual ( outputs . doc_scores . shape , ( input_ids . shape [ 0 ] , self . n_docs ) )
def test_model_with_retriever ( self ) :
inputs_dict = self . config_and_inputs
self . check_model_with_retriever ( * * inputs_dict )
def test_model_without_retriever ( self ) :
inputs_dict = self . config_and_inputs
self . check_model_without_retriever ( * * inputs_dict )
def test_model_with_encoder_outputs ( self ) :
inputs_dict = self . config_and_inputs
self . check_model_with_encoder_outputs ( * * inputs_dict )
def test_model_generate ( self ) :
inputs_dict = self . config_and_inputs
self . check_model_generate ( * * inputs_dict )
2020-10-19 16:15:52 +03:00
def test_model_with_custom_n_docs ( self ) :
inputs_dict = self . config_and_inputs
inputs_dict [ " n_docs " ] = 1
self . check_model_custom_n_docs ( * * inputs_dict )
def test_model_with_mismatch_n_docs_value ( self ) :
inputs_dict = self . config_and_inputs
inputs_dict [ " retriever_n_docs " ] = 3
inputs_dict [ " generator_n_docs " ] = 2
self . check_model_with_mismatch_n_docs_value ( * * inputs_dict )
2020-09-22 19:29:58 +03:00
@require_torch
@require_retrieval
class RagDPRBartTest ( RagTestMixin , unittest . TestCase ) :
@cached_property
def config_and_inputs ( self ) :
question_encoder_tester = DPRModelTester ( self )
dpr_config_and_inputs = question_encoder_tester . prepare_config_and_inputs ( )
generator_tester = BartModelTester ( self )
bart_config_and_inputs = generator_tester . prepare_config_and_inputs_for_common ( )
( question_encoder_config , input_ids , _ , input_mask , _ , _ , _ ) = dpr_config_and_inputs
( generator_config , bart_inputs_dict ) = bart_config_and_inputs
decoder_input_ids , decoder_attention_mask = bart_inputs_dict [ " input_ids " ] , bart_inputs_dict [ " attention_mask " ]
config = RagConfig . from_question_encoder_generator_configs (
question_encoder_config ,
generator_config ,
n_docs = self . n_docs ,
retrieval_vector_size = self . retrieval_vector_size ,
max_combined_length = self . max_combined_length ,
)
return {
" config " : config ,
" input_ids " : input_ids ,
" attention_mask " : input_mask ,
" decoder_input_ids " : decoder_input_ids ,
" decoder_attention_mask " : decoder_attention_mask ,
}
@require_torch
@require_retrieval
class RagDPRT5Test ( RagTestMixin , unittest . TestCase ) :
@cached_property
def config_and_inputs ( self ) :
question_encoder_tester = DPRModelTester ( self )
dpr_config_and_inputs = question_encoder_tester . prepare_config_and_inputs ( )
2020-11-30 10:34:40 +03:00
generator_tester = T5ModelTester ( self , vocab_size = 1100 )
2020-09-22 19:29:58 +03:00
t5_config_and_inputs = generator_tester . prepare_config_and_inputs ( )
( question_encoder_config , input_ids , _ , input_mask , _ , _ , _ ) = dpr_config_and_inputs
( generator_config , _ , decoder_input_ids , _ , decoder_attention_mask , _ ) = t5_config_and_inputs
config = RagConfig . from_question_encoder_generator_configs (
question_encoder_config ,
generator_config ,
n_docs = self . n_docs ,
retrieval_vector_size = self . retrieval_vector_size ,
max_combined_length = self . max_combined_length ,
)
return {
" config " : config ,
" input_ids " : input_ids ,
" attention_mask " : input_mask ,
" decoder_input_ids " : decoder_input_ids ,
" decoder_attention_mask " : decoder_attention_mask ,
}
@require_torch
@require_retrieval
2020-10-18 21:51:24 +03:00
@require_sentencepiece
@require_tokenizers
2020-11-10 21:23:58 +03:00
@require_torch_non_multi_gpu
2020-09-22 19:29:58 +03:00
class RagModelIntegrationTests ( unittest . TestCase ) :
@cached_property
def sequence_model ( self ) :
return (
RagSequenceForGeneration . from_pretrained_question_encoder_generator (
" facebook/dpr-question_encoder-single-nq-base " , " facebook/bart-large-cnn "
)
. to ( torch_device )
. eval ( )
)
@cached_property
def token_model ( self ) :
return (
RagTokenForGeneration . from_pretrained_question_encoder_generator (
" facebook/dpr-question_encoder-single-nq-base " , " facebook/bart-large-cnn "
)
. to ( torch_device )
. eval ( )
)
def get_rag_config ( self ) :
question_encoder_config = AutoConfig . from_pretrained ( " facebook/dpr-question_encoder-single-nq-base " )
generator_config = AutoConfig . from_pretrained ( " facebook/bart-large-cnn " )
return RagConfig . from_question_encoder_generator_configs (
question_encoder_config ,
generator_config ,
bos_token_id = 0 ,
decoder_start_token_id = 2 ,
eos_token_id = 2 ,
is_encoder_decoder = True ,
pad_token_id = 1 ,
vocab_size = 50264 ,
title_sep = " / " ,
doc_sep = " // " ,
n_docs = 5 ,
max_combined_length = 300 ,
dataset = " wiki_dpr " ,
dataset_split = " train " ,
index_name = " exact " ,
index_path = None ,
use_dummy_dataset = True ,
retrieval_vector_size = 768 ,
retrieval_batch_size = 8 ,
)
@slow
def test_rag_sequence_inference ( self ) :
rag_config = self . get_rag_config ( )
rag_decoder_tokenizer = BartTokenizer . from_pretrained ( " facebook/bart-large-cnn " )
rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer . from_pretrained (
" facebook/dpr-question_encoder-single-nq-base "
)
rag_retriever = RagRetriever (
rag_config ,
question_encoder_tokenizer = rag_question_encoder_tokenizer ,
generator_tokenizer = rag_decoder_tokenizer ,
)
rag_sequence = self . sequence_model
rag_sequence . set_retriever ( rag_retriever )
input_ids = rag_question_encoder_tokenizer (
" who sings does he love me with reba " , return_tensors = " pt "
) . input_ids
decoder_input_ids = rag_decoder_tokenizer ( " Linda Davis " , return_tensors = " pt " ) . input_ids
input_ids = input_ids . to ( torch_device )
decoder_input_ids = decoder_input_ids . to ( torch_device )
with torch . no_grad ( ) :
output = rag_sequence (
input_ids ,
labels = decoder_input_ids ,
)
expected_shape = torch . Size ( [ 5 , 5 , 50264 ] )
self . assertEqual ( output . logits . shape , expected_shape )
expected_doc_scores = torch . tensor ( [ [ 75.0286 , 74.4998 , 74.0804 , 74.0306 , 73.9504 ] ] ) . to ( torch_device )
_assert_tensors_equal ( expected_doc_scores , output . doc_scores , atol = TOLERANCE )
2020-09-25 17:12:46 +03:00
expected_loss = torch . tensor ( [ 36.7368 ] ) . to ( torch_device )
2020-09-22 19:29:58 +03:00
_assert_tensors_equal ( expected_loss , output . loss , atol = TOLERANCE )
@slow
def test_rag_token_inference ( self ) :
rag_config = self . get_rag_config ( )
rag_decoder_tokenizer = BartTokenizer . from_pretrained ( " facebook/bart-large-cnn " )
rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer . from_pretrained (
" facebook/dpr-question_encoder-single-nq-base "
)
rag_retriever = RagRetriever (
rag_config ,
question_encoder_tokenizer = rag_question_encoder_tokenizer ,
generator_tokenizer = rag_decoder_tokenizer ,
)
rag_token = self . token_model
rag_token . set_retriever ( rag_retriever )
input_ids = rag_question_encoder_tokenizer (
" who sings does he love me with reba " , return_tensors = " pt "
) . input_ids
decoder_input_ids = rag_decoder_tokenizer ( " Linda Davis " , return_tensors = " pt " ) . input_ids
input_ids = input_ids . to ( torch_device )
decoder_input_ids = decoder_input_ids . to ( torch_device )
with torch . no_grad ( ) :
output = rag_token (
input_ids ,
labels = decoder_input_ids ,
)
expected_shape = torch . Size ( [ 5 , 5 , 50264 ] )
self . assertEqual ( output . logits . shape , expected_shape )
expected_doc_scores = torch . tensor ( [ [ 75.0286 , 74.4998 , 74.0804 , 74.0306 , 73.9504 ] ] ) . to ( torch_device )
_assert_tensors_equal ( expected_doc_scores , output . doc_scores , atol = TOLERANCE )
2020-09-25 17:12:46 +03:00
expected_loss = torch . tensor ( [ 36.3557 ] ) . to ( torch_device )
2020-09-22 19:29:58 +03:00
_assert_tensors_equal ( expected_loss , output . loss , atol = TOLERANCE )
@slow
def test_rag_token_generate_beam ( self ) :
rag_config = self . get_rag_config ( )
rag_decoder_tokenizer = BartTokenizer . from_pretrained ( " facebook/bart-large-cnn " )
rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer . from_pretrained (
" facebook/dpr-question_encoder-single-nq-base "
)
rag_retriever = RagRetriever (
rag_config ,
question_encoder_tokenizer = rag_question_encoder_tokenizer ,
generator_tokenizer = rag_decoder_tokenizer ,
)
rag_token = self . token_model
rag_token . set_retriever ( rag_retriever )
input_ids = rag_question_encoder_tokenizer (
" who sings does he love me with reba " , return_tensors = " pt "
) . input_ids
input_ids = input_ids . to ( torch_device )
output_ids = rag_token . generate (
input_ids ,
decoder_start_token_id = rag_token . generator . config . decoder_start_token_id ,
num_beams = 2 ,
num_return_sequences = 2 ,
)
# sequence generate test
output_text_1 = rag_decoder_tokenizer . decode ( output_ids [ 0 ] , skip_special_tokens = True )
output_text_2 = rag_decoder_tokenizer . decode ( output_ids [ 1 ] , skip_special_tokens = True )
# Expected outputs as given by model at integration time.
2020-09-25 17:12:46 +03:00
EXPECTED_OUTPUT_TEXT_1 = " \" She ' s My Kind of Girl "
EXPECTED_OUTPUT_TEXT_2 = " \" She ' s My Kind of Love "
2020-09-22 19:29:58 +03:00
self . assertEqual ( output_text_1 , EXPECTED_OUTPUT_TEXT_1 )
self . assertEqual ( output_text_2 , EXPECTED_OUTPUT_TEXT_2 )
@slow
2020-09-25 17:12:46 +03:00
def test_rag_sequence_generate_beam ( self ) :
2020-09-22 19:29:58 +03:00
rag_config = self . get_rag_config ( )
rag_decoder_tokenizer = BartTokenizer . from_pretrained ( " facebook/bart-large-cnn " )
rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer . from_pretrained (
" facebook/dpr-question_encoder-single-nq-base "
)
rag_retriever = RagRetriever (
rag_config ,
question_encoder_tokenizer = rag_question_encoder_tokenizer ,
generator_tokenizer = rag_decoder_tokenizer ,
)
2020-12-14 14:32:26 +03:00
rag_sequence = self . sequence_model
rag_sequence . set_retriever ( rag_retriever )
2020-09-22 19:29:58 +03:00
2020-09-25 17:12:46 +03:00
input_ids = rag_question_encoder_tokenizer (
" who sings does he love me with reba " , return_tensors = " pt "
) . input_ids
2020-09-22 19:29:58 +03:00
2020-09-25 17:12:46 +03:00
input_ids = input_ids . to ( torch_device )
2020-09-22 19:29:58 +03:00
2020-12-14 14:32:26 +03:00
output_ids = rag_sequence . generate (
2020-09-22 19:29:58 +03:00
input_ids ,
2020-12-14 14:32:26 +03:00
decoder_start_token_id = rag_sequence . generator . config . decoder_start_token_id ,
2020-09-25 17:12:46 +03:00
num_beams = 2 ,
num_return_sequences = 2 ,
2020-09-22 19:29:58 +03:00
)
# sequence generate test
output_text_1 = rag_decoder_tokenizer . decode ( output_ids [ 0 ] , skip_special_tokens = True )
output_text_2 = rag_decoder_tokenizer . decode ( output_ids [ 1 ] , skip_special_tokens = True )
# Expected outputs as given by model at integration time.
2020-09-25 17:12:46 +03:00
EXPECTED_OUTPUT_TEXT_1 = """ \" She ' s My Kind of Girl \" was released through Epic Records in Japan in March 1972, giving the duo a Top 10 hit. Two more singles were released in Japan, \" En Carousel \" and \" Love Has Its Ways \" Ulvaeus and Andersson persevered with their songwriting and experimented with new sounds and vocal arrangements. """
EXPECTED_OUTPUT_TEXT_2 = """ In September 2018, Björn Ulvaeus revealed that the two new songs, \" I Still Have Faith In You \" and \" Don ' t Shut Me Down \" , would be released no earlier than March 2019. The two new tracks will feature in a TV special set to air later in the year. """
2020-09-22 19:29:58 +03:00
self . assertEqual ( output_text_1 , EXPECTED_OUTPUT_TEXT_1 )
self . assertEqual ( output_text_2 , EXPECTED_OUTPUT_TEXT_2 )
2020-09-25 17:12:46 +03:00
@property
def test_data_questions ( self ) :
return [
" who got the first nobel prize in physics " ,
" when is the next deadpool movie being released " ,
" which mode is used for short wave broadcast service " ,
" who is the owner of reading football club " ,
" when is the next scandal episode coming out " ,
" when is the last time the philadelphia won the superbowl " ,
" what is the most current adobe flash player version " ,
" how many episodes are there in dragon ball z " ,
]
2020-09-22 19:29:58 +03:00
@slow
def test_rag_sequence_generate_batch ( self ) :
2020-09-25 17:12:46 +03:00
tokenizer = RagTokenizer . from_pretrained ( " facebook/rag-sequence-nq " )
retriever = RagRetriever . from_pretrained (
" facebook/rag-sequence-nq " , index_name = " exact " , use_dummy_dataset = True
2020-09-22 19:29:58 +03:00
)
2020-12-14 14:32:26 +03:00
rag_sequence = RagSequenceForGeneration . from_pretrained ( " facebook/rag-sequence-nq " , retriever = retriever ) . to (
2020-09-25 17:12:46 +03:00
torch_device
2020-09-22 19:29:58 +03:00
)
2020-09-25 17:12:46 +03:00
input_dict = tokenizer (
self . test_data_questions ,
2020-09-22 19:29:58 +03:00
return_tensors = " pt " ,
padding = True ,
truncation = True ,
2020-09-25 00:22:04 +03:00
)
2020-09-22 19:29:58 +03:00
2020-09-25 00:22:04 +03:00
input_ids = input_dict . input_ids . to ( torch_device )
attention_mask = input_dict . attention_mask . to ( torch_device )
2020-09-22 19:29:58 +03:00
output_ids = rag_sequence . generate (
input_ids ,
2020-09-25 00:22:04 +03:00
attention_mask = attention_mask ,
2020-09-22 19:29:58 +03:00
)
2020-09-25 17:12:46 +03:00
outputs = tokenizer . batch_decode ( output_ids , skip_special_tokens = True )
EXPECTED_OUTPUTS = [
" albert einstein " ,
" june 22, 2018 " ,
" amplitude modulation " ,
" tim besley ( chairman ) " ,
" june 20, 2018 " ,
" 1980 " ,
" 7.0 " ,
" 8 " ,
]
self . assertListEqual ( outputs , EXPECTED_OUTPUTS )
    # NOTE: the test below comes from PR #9220 ("[RagSequenceForGeneration]
    # generate 'without' input_ids"). That PR fixed RagSequenceForGeneration
    # generation so it can run from precomputed context_input_ids /
    # context_attention_mask / doc_scores instead of raw input_ids, and added
    # test_rag_sequence_generate_batch_from_context_input_ids to cover it.
    # (The same PR also introduced the TFDPR model and its integration tests.)
@slow
def test_rag_sequence_generate_batch_from_context_input_ids ( self ) :
tokenizer = RagTokenizer . from_pretrained ( " facebook/rag-sequence-nq " )
retriever = RagRetriever . from_pretrained (
" facebook/rag-sequence-nq " , index_name = " exact " , use_dummy_dataset = True
)
rag_sequence = RagSequenceForGeneration . from_pretrained ( " facebook/rag-sequence-nq " , retriever = retriever ) . to (
torch_device
)
input_dict = tokenizer (
self . test_data_questions ,
return_tensors = " pt " ,
padding = True ,
truncation = True ,
)
input_ids = input_dict . input_ids . to ( torch_device )
attention_mask = input_dict . attention_mask . to ( torch_device )
question_hidden_states = rag_sequence . question_encoder ( input_ids , attention_mask = attention_mask ) [ 0 ]
docs_dict = retriever (
input_ids . cpu ( ) . detach ( ) . numpy ( ) , question_hidden_states . cpu ( ) . detach ( ) . numpy ( ) , return_tensors = " pt "
)
doc_scores = torch . bmm (
question_hidden_states . unsqueeze ( 1 ) ,
docs_dict [ " retrieved_doc_embeds " ] . to ( torch_device ) . float ( ) . transpose ( 1 , 2 ) ,
) . squeeze ( 1 )
output_ids = rag_sequence . generate (
context_input_ids = docs_dict [ " context_input_ids " ] . to ( torch_device ) ,
context_attention_mask = docs_dict [ " context_attention_mask " ] . to ( torch_device ) ,
doc_scores = doc_scores . to ( torch_device ) ,
do_deduplication = True ,
)
outputs = tokenizer . batch_decode ( output_ids , skip_special_tokens = True )
EXPECTED_OUTPUTS = [
" albert einstein " ,
" june 22, 2018 " ,
" amplitude modulation " ,
" tim besley ( chairman ) " ,
" june 20, 2018 " ,
" 1980 " ,
" 7.0 " ,
" 8 " ,
]
self . assertListEqual ( outputs , EXPECTED_OUTPUTS )
2020-09-22 19:29:58 +03:00
@slow
2020-09-25 17:12:46 +03:00
def test_rag_token_generate_batch ( self ) :
tokenizer = RagTokenizer . from_pretrained ( " facebook/rag-token-nq " )
retriever = RagRetriever . from_pretrained ( " facebook/rag-token-nq " , index_name = " exact " , use_dummy_dataset = True )
rag_token = RagTokenForGeneration . from_pretrained ( " facebook/rag-token-nq " , retriever = retriever ) . to (
torch_device
2020-09-22 19:29:58 +03:00
)
2020-09-25 17:12:46 +03:00
input_dict = tokenizer (
self . test_data_questions ,
return_tensors = " pt " ,
padding = True ,
truncation = True ,
)
2020-09-22 19:29:58 +03:00
2020-09-25 17:12:46 +03:00
input_ids = input_dict . input_ids . to ( torch_device )
attention_mask = input_dict . attention_mask . to ( torch_device )
2020-09-22 19:29:58 +03:00
output_ids = rag_token . generate (
input_ids ,
2020-09-25 17:12:46 +03:00
attention_mask = attention_mask ,
2020-09-22 19:29:58 +03:00
)
2020-09-25 17:12:46 +03:00
outputs = tokenizer . batch_decode ( output_ids , skip_special_tokens = True )
EXPECTED_OUTPUTS = [
" albert einstein " ,
" september 22, 2017 " ,
" amplitude modulation " ,
" stefan persson " ,
" april 20, 2018 " ,
" the 1970s " ,
" 7.1. 2 " ,
" 13 " ,
]
self . assertListEqual ( outputs , EXPECTED_OUTPUTS )
2020-09-22 19:29:58 +03:00
@require_torch
@require_retrieval
class RagModelSaveLoadTests ( unittest . TestCase ) :
def get_rag_config ( self ) :
question_encoder_config = AutoConfig . from_pretrained ( " facebook/dpr-question_encoder-single-nq-base " )
generator_config = AutoConfig . from_pretrained ( " facebook/bart-large-cnn " )
return RagConfig . from_question_encoder_generator_configs (
question_encoder_config ,
generator_config ,
bos_token_id = 0 ,
decoder_start_token_id = 2 ,
eos_token_id = 2 ,
is_encoder_decoder = True ,
pad_token_id = 1 ,
vocab_size = 50264 ,
title_sep = " / " ,
doc_sep = " // " ,
n_docs = 5 ,
max_combined_length = 300 ,
dataset = " wiki_dpr " ,
dataset_split = " train " ,
index_name = " exact " ,
index_path = None ,
use_dummy_dataset = True ,
retrieval_vector_size = 768 ,
retrieval_batch_size = 8 ,
)
@slow
def test_rag_sequence_from_pretrained ( self ) :
rag_config = self . get_rag_config ( )
rag_decoder_tokenizer = BartTokenizer . from_pretrained ( " facebook/bart-large-cnn " )
rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer . from_pretrained (
" facebook/dpr-question_encoder-single-nq-base "
)
rag_retriever = RagRetriever (
rag_config ,
question_encoder_tokenizer = rag_question_encoder_tokenizer ,
generator_tokenizer = rag_decoder_tokenizer ,
)
input_ids = rag_question_encoder_tokenizer (
" who sings does he love me with reba " , return_tensors = " pt "
) . input_ids
decoder_input_ids = rag_decoder_tokenizer ( " Linda Davis " , return_tensors = " pt " ) . input_ids
input_ids = input_ids . to ( torch_device )
decoder_input_ids = decoder_input_ids . to ( torch_device )
with tempfile . TemporaryDirectory ( ) as tmp_dirname :
rag_sequence = RagSequenceForGeneration . from_pretrained_question_encoder_generator (
" facebook/dpr-question_encoder-single-nq-base " ,
" facebook/bart-large-cnn " ,
retriever = rag_retriever ,
config = rag_config ,
) . to ( torch_device )
# check that the from pretrained methods work
rag_sequence . save_pretrained ( tmp_dirname )
rag_sequence . from_pretrained ( tmp_dirname , retriever = rag_retriever )
rag_sequence . to ( torch_device )
with torch . no_grad ( ) :
output = rag_sequence (
input_ids ,
labels = decoder_input_ids ,
)
loss_pretrained = output . loss
del rag_sequence
question_encoder = AutoModel . from_pretrained ( " facebook/dpr-question_encoder-single-nq-base " )
generator = AutoModelForSeq2SeqLM . from_pretrained ( " facebook/bart-large-cnn " )
rag_sequence = RagSequenceForGeneration (
config = rag_config , question_encoder = question_encoder , generator = generator , retriever = rag_retriever
)
rag_sequence . to ( torch_device )
with torch . no_grad ( ) :
output = rag_sequence (
input_ids ,
labels = decoder_input_ids ,
)
loss_init = output . loss
self . assertAlmostEqual ( loss_pretrained . item ( ) , loss_init . item ( ) , places = 4 )
@slow
def test_rag_token_from_pretrained ( self ) :
rag_config = self . get_rag_config ( )
rag_decoder_tokenizer = BartTokenizer . from_pretrained ( " facebook/bart-large-cnn " )
rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer . from_pretrained (
" facebook/dpr-question_encoder-single-nq-base "
)
rag_retriever = RagRetriever (
rag_config ,
question_encoder_tokenizer = rag_question_encoder_tokenizer ,
generator_tokenizer = rag_decoder_tokenizer ,
)
input_ids = rag_question_encoder_tokenizer (
" who sings does he love me with reba " , return_tensors = " pt "
) . input_ids
decoder_input_ids = rag_decoder_tokenizer ( " Linda Davis " , return_tensors = " pt " ) . input_ids
input_ids = input_ids . to ( torch_device )
decoder_input_ids = decoder_input_ids . to ( torch_device )
with tempfile . TemporaryDirectory ( ) as tmp_dirname :
rag_token = RagTokenForGeneration . from_pretrained_question_encoder_generator (
" facebook/dpr-question_encoder-single-nq-base " ,
" facebook/bart-large-cnn " ,
retriever = rag_retriever ,
config = rag_config ,
) . to ( torch_device )
# check that the from pretrained methods work
rag_token . save_pretrained ( tmp_dirname )
rag_token . from_pretrained ( tmp_dirname , retriever = rag_retriever )
rag_token . to ( torch_device )
with torch . no_grad ( ) :
output = rag_token (
input_ids ,
labels = decoder_input_ids ,
)
loss_pretrained = output . loss
del rag_token
question_encoder = AutoModel . from_pretrained ( " facebook/dpr-question_encoder-single-nq-base " )
generator = AutoModelForSeq2SeqLM . from_pretrained ( " facebook/bart-large-cnn " )
rag_token = RagTokenForGeneration (
config = rag_config , question_encoder = question_encoder , generator = generator , retriever = rag_retriever
)
rag_token . to ( torch_device )
with torch . no_grad ( ) :
output = rag_token (
input_ids ,
labels = decoder_input_ids ,
)
loss_init = output . loss
self . assertAlmostEqual ( loss_pretrained . item ( ) , loss_init . item ( ) , places = 4 )