diff --git a/.gitignore b/.gitignore index 7da929be8..5daaa51ed 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,7 @@ __pycache__/ # tests and logs tests/fixtures logs/ +lightning_logs/ # Distribution / packaging .Python @@ -139,6 +140,7 @@ runs /wandb /examples/runs /examples/**/*.args +/examples/rag/sweep # data /data diff --git a/docs/source/index.rst b/docs/source/index.rst index 77f0a1b7d..8ff8d206b 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -231,6 +231,7 @@ conversion utilities for the following models: model_doc/lxmert model_doc/bertgeneration model_doc/layoutlm + model_doc/rag internal/modeling_utils internal/tokenization_utils internal/pipelines_utils diff --git a/docs/source/model_doc/rag.rst b/docs/source/model_doc/rag.rst new file mode 100644 index 000000000..46accdedb --- /dev/null +++ b/docs/source/model_doc/rag.rst @@ -0,0 +1,88 @@ +RAG +---------------------------------------------------- + +Overview +~~~~~~~~~~~~~~~~~~~~~ + +Retrieval-augmented generation ("RAG") models combine the powers of pretrained dense retrieval (DPR) and Seq2Seq models. +RAG models retrieve docs, pass them to a seq2seq model, then marginalize to generate outputs. +The retriever and seq2seq modules are initialized from pretrained models, and fine-tuned jointly, allowing both retrieval and generation to adapt to downstream tasks. + +It is based on the paper `Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks `__ by Patrick Lewis, Ethan Perez, Aleksandara Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, Sebastian Riedel, Douwe Kiela. + +The abstract from the paper is the following: + +*Large pre-trained language models have been shown to store factual knowledge +in their parameters, and achieve state-of-the-art results when fine-tuned on +downstream NLP tasks. However, their ability to access and precisely manipulate +knowledge is still limited, and hence on knowledge-intensive tasks, their +performance lags behind task-specific architectures. Additionally, providing +provenance for their decisions and updating their world knowledge remain open +research problems. Pre-trained models with a differentiable access mechanism to +explicit nonparametric memory can overcome this issue, but have so far been only +investigated for extractive downstream tasks. We explore a general-purpose +fine-tuning recipe for retrieval-augmented generation (RAG) — models which combine +pre-trained parametric and non-parametric memory for language generation. We +introduce RAG models where the parametric memory is a pre-trained seq2seq model and +the non-parametric memory is a dense vector index of Wikipedia, accessed with +a pre-trained neural retriever. We compare two RAG formulations, one which +conditions on the same retrieved passages across the whole generated sequence, the +other can use different passages per token. We fine-tune and evaluate our models +on a wide range of knowledge-intensive NLP tasks and set the state-of-the-art +on three open domain QA tasks, outperforming parametric seq2seq models and +task-specific retrieve-and-extract architectures. For language generation tasks, we +find that RAG models generate more specific, diverse and factual language than a +state-of-the-art parametric-only seq2seq baseline.* + + + +RagConfig +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.RagConfig + :members: + + +RagTokenizer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.RagTokenizer + :members: + + +Rag specific outputs +~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.modeling_rag.RetrievAugLMMarginOutput + :members: + +.. autoclass:: transformers.modeling_rag.RetrievAugLMOutput + :members: + + +RAGRetriever +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.RagRetriever + :members: + + +RagModel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.RagModel + :members: forward + + +RagSequenceForGeneration +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.RagSequenceForGeneration + :members: forward, generate + + +RagTokenForGeneration +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.RagTokenForGeneration + :members: forward, generate diff --git a/docs/source/model_summary.rst b/docs/source/model_summary.rst index fe253a266..e88d8de02 100644 --- a/docs/source/model_summary.rst +++ b/docs/source/model_summary.rst @@ -654,7 +654,7 @@ DPR Models - + Doc @@ -672,6 +672,27 @@ DPR consists in three models: DPR's pipeline (not implemented yet) uses a retrieval step to find the top k contexts given a certain question, and then it calls the reader with the question and the retrieved documents to get the answer. +RAG +---------------------------------------------- + +.. raw:: html + + + Models + + + Doc + + +`Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks `_, +Patrick Lewis, Ethan Perez, Aleksandara Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, Sebastian Riedel, Douwe Kiela + +Retrieval-augmented generation ("RAG") models combine the powers of pretrained dense retrieval (DPR) and Seq2Seq models. +RAG models retrieve docs, pass them to a seq2seq model, then marginalize to generate outputs. +The retriever and seq2seq modules are initialized from pretrained models, and fine-tuned jointly, allowing both retrieval and generation to adapt to downstream tasks. + +The two models RAG-Token and RAG-Sequence are available for generation. + More technical aspects ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/examples/lightning_base.py b/examples/lightning_base.py index e7c41a3e7..c66884d78 100644 --- a/examples/lightning_base.py +++ b/examples/lightning_base.py @@ -366,6 +366,8 @@ def generic_train( if args.gpus > 1: train_params["distributed_backend"] = "ddp" + train_params["accumulate_grad_batches"] = args.accumulate_grad_batches + trainer = pl.Trainer.from_argparse_args( args, weights_summary=None, diff --git a/examples/longform-qa/eli5_app.py b/examples/longform-qa/eli5_app.py index a57238df4..4bb8de178 100644 --- a/examples/longform-qa/eli5_app.py +++ b/examples/longform-qa/eli5_app.py @@ -1,10 +1,10 @@ import datasets -import faiss import numpy as np import streamlit as st import torch from elasticsearch import Elasticsearch +import faiss import transformers from eli5_utils import ( embed_questions_for_retrieval, diff --git a/examples/longform-qa/eli5_utils.py b/examples/longform-qa/eli5_utils.py index 95e1eaf6b..60bc424a7 100644 --- a/examples/longform-qa/eli5_utils.py +++ b/examples/longform-qa/eli5_utils.py @@ -5,7 +5,6 @@ from random import choice, randint from time import time import datasets # noqa: F401 -import faiss # noqa: F401 import numpy as np import pandas as pd import torch @@ -15,6 +14,7 @@ from elasticsearch.helpers import bulk, streaming_bulk # noqa: F401 from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler from tqdm import tqdm +import faiss # noqa: F401 from transformers import AdamW, AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer, get_linear_schedule_with_warmup diff --git a/examples/rag/README.md b/examples/rag/README.md new file mode 100644 index 000000000..e4453868b --- /dev/null +++ b/examples/rag/README.md @@ -0,0 +1,88 @@ +# Intro +RAG is a seq2seq model which encapsulates two core components: a question encoder and a generator. +During a forward pass, we encode the input with the question encoder and pass it +to the retriever to extract relevant context documents. The documents are then prepended to the input. +Such contextualized inputs is passed to the generator. + +The question encoder can be any `autoencoding` model, preferably :obj:`~transformers.DPRQuestionEncoder`, and the generator can be any `seq2seq` model, preferably :obj:`~transformers.BartForConditionalGeneration`. + +The model can be initialized with a :obj:`~transformers.RagRetriever` for end-to-end generation or used in combination with the outputs of a retriever in multiple steps - see examples for more details. +The model is compatible any `autoencoding` model as the ``question_encoder`` and any `seq2seq` model with language model head as the ``generator``. +The model has been tested with :class:`~transformers.DPRQuestionEncoder` as the ``question_encoder`` and :class:`~transformers.BartForConditionalGeneration` or :class:`~transformers.T5ForConditionalGeneration` as the ``generator``. + +RAG models were released with the paper `Retrieval-Augmented Generation for +Knowledge-Intensive NLP Tasks `_ by Patrick Lewis, Ethan Perez, Aleksandra Piktus et al. + + +# Finetuning +Our finetuning logic is based on scripts from [`examples/seq2seq`](https://github.com/huggingface/transformers/tree/master/examples/seq2seq). +Follow instructions there regarding data preprocessing. A sample finetuning command: + +``` +python examples/rag/finetune.py \ + --data_dir $DATA_DIR \ + --output_dir $OUTPUT_DIR \ + --model_name_or_path $MODEL_NAME_OR_PATH \ + --model_type rag_sequence \ + --fp16 \ + --gpus 8 +``` + + +# Evaluation +Apart from the parameters specifying the model to evaluate and some extra parameters, the evaluation script expects paths to two files: +- `evaluation_set` - a path to a file specifying the evaluation dataset, a single datapoint per line, e.g. +```who is the owner of reading football club``` +- `gold_data_path` - a path to a file contaning ground truth answers for datapoints from the `evaluation_set`. + +We expect the following formats of the gold data file: + +- for e2e evaluation, we support two formats of the gold file: + - `qa` - where a single line in the following format: input [tab] output_list, e.g.: + ``` + who is the owner of reading football club ['Xiu Li Dai', 'Dai Yongge', 'Dai Xiuli', 'Yongge Dai'] + ``` + - `ans` - where a single line of the gold file contains the expected output string, e.g.: + ``` + Xiu Li Dai + ``` + +- for retrieval evaluation, we expect a tab-separated list of Wikipedia page titles constituting positive contexts for a given query, e.g. given a question `who sings does he love me with reba`, a line with ground truth retrieval data could look as follows: +``` +Does He Love You Does He Love You Red Sandy Spika dress of Reba McEntire Greatest Hits Volume Two (Reba McEntire album) Shoot for the Moon (album) +``` + +## Retrieval evaluation + +We demonstrate how to evaluate retrieval against DPR evaluation data. You can download respective files from links listed [here](https://github.com/facebookresearch/DPR/blob/master/data/download_data.py#L39-L45). + +1. Download and unzip the gold data file. We use the `biencoder-nq-dev` from https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-dev.json.gz. +2. Parse the unziped file using the `parse_dpr_relevance_data.py` +``` +python examples/rag/parse_dpr_relevance_data.py --src_path path/to/unziped/biencoder-nq-dev.json --evaluation_set path/to/output/biencoder-nq-dev.questions --gold_data_path path/to/output/biencoder-nq-dev.pages +``` +3. Run evaluation: +``` +python examples/rag/eval_rag.py \ + --model_name_or_path $MODEL_NAME_OR_PATH \ # model name or path of the model we're evaluating + --model_type rag_sequence \ # RAG model type (rag_token or rag_sequence) + --evaluation_set path/to/output/biencoder-nq-dev.questions \ # an input dataset for evaluation + --gold_data_path path/to/output/biencoder-nq-dev.pages \ # a dataset containing ground truth answers for samples from the evaluation_set + --predictions_path path/to/retrieval_preds.tsv \ # name of file in which predictions will be stored + --eval_mode retrieval \ # indicates whether we're performing retrieval evaluation or e2e evaluation + --recalculate # if predictions_filename already exists, and this option is set - we regenerate the answers, otherwise we reuse the predicsion file to calculate metrics. +``` + + +## End-to-end evaluation +``` +python examples/rag/eval_rag.py \ + --model_name_or_path $MODEL_NAME_OR_PATH \ + --model_type rag_sequence \ + --evaluation_set path/to/test.source \ + --gold_data_path path/to/gold_data \ + --predictions_path path/to/e2e_preds.txt \ + --eval_mode e2e \ # indicates whether we're performing retrieval evaluation or e2e evaluation (default) + --n_docs 5 \ # You can experiment with retrieving different number of documents at evaluation time + --print_predictions +``` diff --git a/examples/rag/__init__.py b/examples/rag/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/rag/callbacks.py b/examples/rag/callbacks.py new file mode 100644 index 000000000..222db114d --- /dev/null +++ b/examples/rag/callbacks.py @@ -0,0 +1,30 @@ +import logging +import os + +from pytorch_lightning.callbacks import ModelCheckpoint + + +logger = logging.getLogger(__name__) + + +def get_checkpoint_callback(output_dir, metric): + """Saves the best model by validation EM score.""" + if metric == "rouge2": + exp = "{val_avg_rouge2:.4f}-{step_count}" + elif metric == "bleu": + exp = "{val_avg_bleu:.4f}-{step_count}" + elif metric == "em": + exp = "{val_avg_em:.4f}-{step_count}" + else: + raise NotImplementedError( + f"seq2seq callbacks only support rouge2 and bleu, got {metric}, You can make your own by adding to this function." + ) + + checkpoint_callback = ModelCheckpoint( + filepath=os.path.join(output_dir, exp), + monitor=f"val_{metric}", + mode="max", + save_top_k=3, + period=0, # maybe save a checkpoint every time val is run, not just end of epoch. + ) + return checkpoint_callback diff --git a/examples/rag/distributed_retriever.py b/examples/rag/distributed_retriever.py new file mode 100644 index 000000000..45af93c55 --- /dev/null +++ b/examples/rag/distributed_retriever.py @@ -0,0 +1,135 @@ +import logging +import os +from typing import List, Tuple + +import numpy as np +import psutil +import torch +import torch.distributed as dist + +from transformers import RagRetriever + + +logger = logging.getLogger(__name__) + + +class RagPyTorchDistributedRetriever(RagRetriever): + """ + A distributed retriever built on top of the ``torch.distributed`` communication package. During training all workers + initalize their own instance of the retriever, however, only the main worker loads the index into memory. The index is stored + in cpu memory. The index will also work well in a non-distributed setup. + + Args: + config (:class:`~transformers.RagConfig`): + The configuration of the RAG model this Retriever is used with. Contains parameters indicating which ``Index`` to build. + question_encoder_tokenizer (:class:`~transformers.PretrainedTokenizer`): + The tokenizer that was used to tokenize the question. + It is used to decode the question and then use the generator_tokenizer. + generator_tokenizer (:class:`~transformers.PretrainedTokenizer`): + The tokenizer used for the generator part of the RagModel. + """ + + _init_retrieval = False + + def __init__(self, config, question_encoder_tokenizer, generator_tokenizer): + super().__init__( + config, question_encoder_tokenizer=question_encoder_tokenizer, generator_tokenizer=generator_tokenizer + ) + + self.process_group = None + + def init_retrieval(self, distributed_port: int): + """ + Retriever initalization function, needs to be called from the training process. The function sets some common parameters + and environment variables. On top of that, (only) the main process in the process group loads the index into memory. + + Args: + distributed_port (:obj:`int`): + The port on which the main communication of the training run is carried out. We set the port for retrieval-related + communication as ``distributed_port + 1``. + """ + + logger.info("initializing retrieval") + + # initializing a separate process group for retrievel as the default + # nccl backend doesn't support gather/scatter operations while gloo + # is too slow to replace nccl for the core gpu communication + if dist.is_initialized(): + logger.info("dist initialized") + # needs to be set manually + os.environ["GLOO_SOCKET_IFNAME"] = self._infer_socket_ifname() + # avoid clash with the NCCL port + os.environ["MASTER_PORT"] = str(distributed_port + 1) + self.process_group = dist.new_group(ranks=None, backend="gloo") + + # initialize retriever only on the main worker + if not dist.is_initialized() or self._is_main(): + logger.info("dist not initialized / main") + self.index.init_index() + + # all processes wait untill the retriever is initialized by the main process + if dist.is_initialized(): + torch.distributed.barrier(group=self.process_group) + + def _is_main(self): + return dist.get_rank(group=self.process_group) == 0 + + def _scattered(self, scatter_list, target_shape, target_type=torch.float32): + target_tensor = torch.empty(target_shape, dtype=target_type) + dist.scatter(target_tensor, src=0, scatter_list=scatter_list, group=self.process_group) + return target_tensor + + def _infer_socket_ifname(self): + addrs = psutil.net_if_addrs() + # a hacky way to deal with varying network interface names + ifname = next((addr for addr in addrs if addr.startswith("e")), None) + return ifname + + def retrieve(self, question_hidden_states: np.ndarray, n_docs: int) -> Tuple[np.ndarray, List[dict]]: + """ + Retrieves documents for specified ``question_hidden_states``. The main process, which has the access to the index stored in memory, gathers queries + from all the processes in the main training process group, performs the retrieval and scatters back the results. + + Args: + question_hidden_states (:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size)`): + A batch of query vectors to retrieve with. + n_docs (:obj:`int`): + The number of docs retrieved per query. + + Ouput: + retrieved_doc_embeds (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs, dim)` + The retrieval embeddings of the retrieved docs per query. + doc_ids (:obj:`np.ndarray` of shape :obj:`batch_size, n_docs`) + The ids of the documents in the index + doc_dicts (:obj:`List[dict]`): + The retrieved_doc_embeds examples per query. + """ + + # single GPU training + if not dist.is_initialized(): + doc_ids, retrieved_doc_embeds = self._main_retrieve(question_hidden_states, n_docs) + return retrieved_doc_embeds, doc_ids, self.index.get_doc_dicts(doc_ids) + + # distributed training + world_size = dist.get_world_size(group=self.process_group) + + # gather logic + gather_list = None + if self._is_main(): + gather_list = [torch.empty(question_hidden_states.shape, dtype=torch.float32) for _ in range(world_size)] + dist.gather(torch.tensor(question_hidden_states), dst=0, gather_list=gather_list, group=self.process_group) + + # scatter logic + n_queries = question_hidden_states.shape[0] + scatter_ids = [] + scatter_vectors = [] + if self._is_main(): + assert len(gather_list) == world_size + ids, vectors = self._main_retrieve(torch.cat(gather_list).numpy(), n_docs) + ids, vectors = torch.tensor(ids), torch.tensor(vectors) + scatter_ids = self._chunk_tensor(ids, n_queries) + scatter_vectors = self._chunk_tensor(vectors, n_queries) + doc_ids = self._scattered(scatter_ids, [n_queries, n_docs], target_type=torch.int64) + retrieved_doc_embeds = self._scattered(scatter_vectors, [n_queries, n_docs, question_hidden_states.shape[1]]) + + return retrieved_doc_embeds.numpy(), doc_ids.numpy(), self.index.get_doc_dicts(doc_ids) diff --git a/examples/rag/eval_rag.py b/examples/rag/eval_rag.py new file mode 100644 index 000000000..a2fc93648 --- /dev/null +++ b/examples/rag/eval_rag.py @@ -0,0 +1,310 @@ +""" Evaluation script for RAG models.""" + +import argparse +import ast +import logging +import os +import sys + +import pandas as pd +import torch +from tqdm import tqdm + +from transformers import BartForConditionalGeneration, RagRetriever, RagSequenceForGeneration, RagTokenForGeneration +from transformers import logging as transformers_logging + + +sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # isort:skip +from examples.rag.utils import exact_match_score, f1_score # noqa: E402 # isort:skip + + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + +transformers_logging.set_verbosity_info() + + +def infer_model_type(model_name_or_path): + if "token" in model_name_or_path: + return "rag_token" + if "sequence" in model_name_or_path: + return "rag_sequence" + if "bart" in model_name_or_path: + return "bart" + return None + + +def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): + return max(metric_fn(prediction, gt) for gt in ground_truths) + + +def get_scores(args, preds_path, gold_data_path): + hypos = [line.strip() for line in open(preds_path, "r").readlines()] + answers = [] + + if args.gold_data_mode == "qa": + data = pd.read_csv(gold_data_path, sep="\t", header=None) + for answer_list in data[1]: + ground_truths = ast.literal_eval(answer_list) + answers.append(ground_truths) + else: + references = [line.strip() for line in open(gold_data_path, "r").readlines()] + answers = [[reference] for reference in references] + + f1 = em = total = 0 + for prediction, ground_truths in zip(hypos, answers): + total += 1 + em += metric_max_over_ground_truths(exact_match_score, prediction, ground_truths) + f1 += metric_max_over_ground_truths(f1_score, prediction, ground_truths) + + em = 100.0 * em / total + f1 = 100.0 * f1 / total + + logger.info(f"F1: {f1:.2f}") + logger.info(f"EM: {em:.2f}") + + +def get_precision_at_k(args, preds_path, gold_data_path): + k = args.k + hypos = [line.strip() for line in open(preds_path, "r").readlines()] + references = [line.strip() for line in open(gold_data_path, "r").readlines()] + + em = total = 0 + for hypo, reference in zip(hypos, references): + hypo_provenance = set(hypo.split("\t")[:k]) + ref_provenance = set(reference.split("\t")[1 : (k + 1)]) + total += 1 + em += len(hypo_provenance & ref_provenance) / k + + em = 100.0 * em / total + logger.info(f"Precision@{k}: {em: .2f}") + + +def evaluate_batch_retrieval(args, rag_model, questions): + def strip_title(title): + if title.startswith('"'): + title = title[1:] + if title.endswith('"'): + title = title[:-1] + return title + + retriever_input_ids = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus( + questions, + return_tensors="pt", + padding=True, + truncation=True, + )["input_ids"].to(args.device) + + question_enc_outputs = rag_model.rag.question_encoder(retriever_input_ids, return_dict=True) + question_enc_pool_output = question_enc_outputs.pooler_output + + result = rag_model.retriever( + retriever_input_ids, + question_enc_pool_output.cpu().detach().to(torch.float32).numpy(), + prefix=rag_model.rag.generator.config.prefix, + n_docs=rag_model.config.n_docs, + return_tensors="pt", + ) + all_docs = rag_model.retriever.index.get_doc_dicts(result.doc_ids) + provenance_strings = [] + for docs in all_docs: + provenance = [strip_title(title) for title in docs["title"]] + provenance_strings.append("\t".join(provenance)) + return provenance_strings + + +def evaluate_batch_e2e(args, rag_model, questions): + with torch.no_grad(): + input_ids = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus( + questions, return_tensors="pt", padding=True, truncation=True + )["input_ids"].to(args.device) + outputs = rag_model.generate( # rag_model overwrites generate + input_ids, + num_beams=args.num_beams, + min_length=args.min_length, + max_length=args.max_length, + early_stopping=False, + num_return_sequences=1, + bad_words_ids=[[0, 0]], # BART likes to repeat BOS tokens, dont allow it to generate more than one + clean_up_tokenization=True, + print_docs=args.print_docs, + ) + answers = rag_model.retriever.generator_tokenizer.batch_decode(outputs, skip_special_tokens=True) + + if args.print_predictions: + for q, a in zip(questions, answers): + logger.info("Q: {} - A: {}".format(q, a)) + + return answers + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_type", + choices=["rag_sequence", "rag_token", "bart"], + type=str, + help="RAG model type: rag_sequence, rag_token or bart, if none specified, the type is inferred from the model_name_or_path", + ) + parser.add_argument( + "--index_name", + default=None, + choices=["hf", "legacy"], + type=str, + help="RAG model retriever type", + ) + parser.add_argument( + "--index_path", + default=None, + type=str, + help="Path to the retrieval index", + ) + parser.add_argument("--n_docs", default=5, type=int, help="Number of retrieved docs") + parser.add_argument( + "--model_name_or_path", + default=None, + type=str, + required=True, + help="Path to pretrained checkpoints or model identifier from huggingface.co/models", + ) + parser.add_argument( + "--eval_mode", + choices=["e2e", "retrieval"], + default="e2e", + type=str, + help="Evaluation mode, e2e calculates exact match and F1 of the downstream task, retrieval calulates precision@k.", + ) + parser.add_argument("--k", default=1, type=int, help="k for the precision@k calculation") + parser.add_argument( + "--evaluation_set", + default=None, + type=str, + required=True, + help="Path to a file containing evaluation samples", + ) + parser.add_argument( + "--gold_data_path", + default=None, + type=str, + required=True, + help="Path to a tab-separated file with gold samples", + ) + parser.add_argument( + "--gold_data_mode", + default="qa", + type=str, + choices=["qa", "ans"], + help="Format of the gold data file" + "qa - a single line in the following format: question [tab] answer_list" + "ans - a single line of the gold file contains the expected answer string", + ) + parser.add_argument( + "--predictions_path", + type=str, + default="predictions.txt", + help="Name of the predictions file, to be stored in the checkpoints directry", + ) + parser.add_argument( + "--eval_all_checkpoints", + action="store_true", + help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number", + ) + parser.add_argument( + "--eval_batch_size", + default=8, + type=int, + help="Batch size per GPU/CPU for evaluation.", + ) + parser.add_argument( + "--recalculate", + help="Recalculate predictions even if the prediction file exists", + action="store_true", + ) + parser.add_argument( + "--num_beams", + default=4, + type=int, + help="Number of beams to be used when generating answers", + ) + parser.add_argument("--min_length", default=1, type=int, help="Min length of the generated answers") + parser.add_argument("--max_length", default=50, type=int, help="Max length of the generated answers") + + parser.add_argument( + "--print_predictions", + action="store_true", + help="If True, prints predictions while evaluating.", + ) + parser.add_argument( + "--print_docs", + action="store_true", + help="If True, prints docs retried while generating.", + ) + args = parser.parse_args() + args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + return args + + +def main(args): + model_kwargs = {} + if args.model_type is None: + args.model_type = infer_model_type(args.model_name_or_path) + assert args.model_type is not None + if args.model_type.startswith("rag"): + model_class = RagTokenForGeneration if args.model_type == "rag_token" else RagSequenceForGeneration + model_kwargs["n_docs"] = args.n_docs + if args.index_name is not None: + model_kwargs["index_name"] = args.index_name + if args.index_path is not None: + model_kwargs["index_path"] = args.index_path + else: + model_class = BartForConditionalGeneration + + checkpoints = ( + [f.path for f in os.scandir(args.model_name_or_path) if f.is_dir()] + if args.eval_all_checkpoints + else [args.model_name_or_path] + ) + + logger.info("Evaluate the following checkpoints: %s", checkpoints) + + score_fn = get_scores if args.eval_mode == "e2e" else get_precision_at_k + evaluate_batch_fn = evaluate_batch_e2e if args.eval_mode == "e2e" else evaluate_batch_retrieval + + for checkpoint in checkpoints: + if os.path.exists(args.predictions_path) and (not args.recalculate): + logger.info("Calculating metrics based on an existing predictions file: {}".format(args.predictions_path)) + score_fn(args, args.predictions_path, args.gold_data_path) + continue + + logger.info("***** Running evaluation for {} *****".format(checkpoint)) + logger.info(" Batch size = %d", args.eval_batch_size) + logger.info(" Predictions will be stored under {}".format(args.predictions_path)) + + if args.model_type.startswith("rag"): + retriever = RagRetriever.from_pretrained(checkpoint, **model_kwargs) + model = model_class.from_pretrained(checkpoint, retriever=retriever, **model_kwargs) + model.retriever.init_retrieval() + else: + model = model_class.from_pretrained(checkpoint, **model_kwargs) + model.to(args.device) + + with open(args.evaluation_set, "r") as eval_file, open(args.predictions_path, "w") as preds_file: + questions = [] + for line in tqdm(eval_file): + questions.append(line.strip()) + if len(questions) == args.eval_batch_size: + answers = evaluate_batch_fn(args, model, questions) + preds_file.write("\n".join(answers) + "\n") + preds_file.flush() + questions = [] + if len(questions) > 0: + answers = evaluate_batch_fn(args, model, questions) + preds_file.write("\n".join(answers)) + preds_file.flush() + + score_fn(args, args.predictions_path, args.gold_data_path) + + +if __name__ == "__main__": + args = get_args() + main(args) diff --git a/examples/rag/finetune.py b/examples/rag/finetune.py new file mode 100644 index 000000000..8648e678b --- /dev/null +++ b/examples/rag/finetune.py @@ -0,0 +1,474 @@ +"""Finetuning script for RAG models. Adapted from examples.seq2seq.finetune.py""" + +import argparse +import glob +import logging +import os +import sys +import time +import warnings +from collections import defaultdict +from pathlib import Path +from typing import Any, Dict, List, Tuple + +import numpy as np +import pytorch_lightning as pl +import torch +import torch.distributed as dist +from torch.utils.data import DataLoader + +from transformers import ( + AutoConfig, + AutoTokenizer, + BartForConditionalGeneration, + RagConfig, + RagSequenceForGeneration, + RagTokenForGeneration, + RagTokenizer, + T5ForConditionalGeneration, + get_linear_schedule_with_warmup, +) +from transformers import logging as transformers_logging + + +sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # noqa: E402 # isort:skip + +from examples.lightning_base import BaseTransformer, add_generic_args, generic_train # noqa: E402 # isort:skip +from examples.rag.callbacks import get_checkpoint_callback # noqa: E402 # isort:skip +from examples.rag.distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip +from examples.rag.utils import ( # noqa: E402 # isort:skip + Seq2SeqDataset, + calculate_exact_match, + is_rag_model, + set_extra_model_params, +) +from examples.seq2seq.callbacks import Seq2SeqLoggingCallback, get_early_stopping_callback # noqa: E402 # isort:skip +from examples.seq2seq.utils import ( # noqa: E402 # isort:skip + flatten_list, + get_git_info, + lmap, + pickle_save, + save_git_info, + save_json, +) + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +transformers_logging.set_verbosity_info() + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +class GenerativeQAModule(BaseTransformer): + mode = "generative_qa" + loss_names = ["loss"] + metric_names = ["em"] + val_metric = "em" + + def __init__(self, hparams, **kwargs): + # when loading from a pytorch lightning checkpoint, hparams are passed as dict + if isinstance(hparams, dict): + hparams = AttrDict(hparams) + if hparams.model_type == "rag_sequence": + self.model_class = RagSequenceForGeneration + elif hparams.model_type == "rag_token": + self.model_class = RagTokenForGeneration + elif hparams.model_type == "bart": + self.model_class = BartForConditionalGeneration + else: + self.model_class = T5ForConditionalGeneration + self.is_rag_model = is_rag_model(hparams.model_type) + + config_class = RagConfig if self.is_rag_model else AutoConfig + config = config_class.from_pretrained(hparams.model_name_or_path) + + # set extra_model_params for generator configs and load_model + extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "attention_dropout", "dropout") + if self.is_rag_model: + if args.prefix is not None: + config.generator.prefix = args.prefix + config.label_smoothing = hparams.label_smoothing + hparams, config.generator = set_extra_model_params(extra_model_params, hparams, config.generator) + retriever = RagPyTorchDistributedRetriever.from_pretrained(hparams.model_name_or_path) + model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config, retriever=retriever) + prefix = config.question_encoder.prefix + else: + if args.prefix is not None: + config.prefix = args.prefix + hparams, config = set_extra_model_params(extra_model_params, hparams, config) + model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config) + prefix = config.prefix + + tokenizer = ( + RagTokenizer.from_pretrained(hparams.model_name_or_path) + if self.is_rag_model + else AutoTokenizer.from_pretrained(hparams.model_name_or_path) + ) + + super().__init__(hparams, config=config, tokenizer=tokenizer, model=model) + + save_git_info(self.hparams.output_dir) + self.output_dir = Path(self.hparams.output_dir) + self.metrics_save_path = Path(self.output_dir) / "metrics.json" + self.hparams_save_path = Path(self.output_dir) / "hparams.pkl" + pickle_save(self.hparams, self.hparams_save_path) + self.step_count = 0 + self.metrics = defaultdict(list) + + self.dataset_kwargs: dict = dict( + data_dir=self.hparams.data_dir, + max_source_length=self.hparams.max_source_length, + prefix=prefix or "", + ) + n_observations_per_split = { + "train": self.hparams.n_train, + "val": self.hparams.n_val, + "test": self.hparams.n_test, + } + self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()} + + self.target_lens = { + "train": self.hparams.max_target_length, + "val": self.hparams.val_max_target_length, + "test": self.hparams.test_max_target_length, + } + assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}" + assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}" + + self.hparams.git_sha = get_git_info()["repo_sha"] + self.num_workers = hparams.num_workers + self.distributed_port = self.hparams.distributed_port + + def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True): + logger.info("Custom init_ddp_connection.") + os.environ["MASTER_PORT"] = str(self.distributed_port) + super().init_ddp_connection(global_rank, world_size, is_slurm_managing_tasks) + if self.is_rag_model: + self.model.retriever.init_retrieval(self.distributed_port) + + def forward(self, input_ids, **kwargs): + return self.model(input_ids, **kwargs) + + def ids_to_clean_text(self, generated_ids: List[int]): + gen_text = self.tokenizer.batch_decode( + generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True + ) + return lmap(str.strip, gen_text) + + def _step(self, batch: dict) -> Tuple: + source_ids, source_mask, target_ids = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"] + + rag_kwargs = {} + if isinstance(self.model, T5ForConditionalGeneration): + decoder_input_ids = self.model._shift_right(target_ids) + lm_labels = target_ids + elif isinstance(self.model, BartForConditionalGeneration): + decoder_input_ids = target_ids[:, :-1].contiguous() + lm_labels = target_ids[:, 1:].clone() + else: + assert self.is_rag_model + generator = self.model.rag.generator + if isinstance(generator, T5ForConditionalGeneration): + decoder_start_token_id = generator.config.decoder_start_token_id + decoder_input_ids = ( + torch.cat( + [torch.Tensor([[decoder_start_token_id]] * target_ids.shape[0]).to(target_ids), target_ids], + dim=1, + ) + if target_ids.shape[0] < self.target_lens["train"] + else generator._shift_right(target_ids) + ) + elif isinstance(generator, BartForConditionalGeneration): + decoder_input_ids = target_ids + lm_labels = decoder_input_ids + rag_kwargs["reduce_loss"] = True + + assert decoder_input_ids is not None + + outputs = self( + source_ids, + attention_mask=source_mask, + decoder_input_ids=decoder_input_ids, + use_cache=False, + labels=lm_labels, + return_dict=True, + **rag_kwargs, + ) + + loss = outputs["loss"] + return (loss,) + + @property + def pad(self) -> int: + raise NotImplementedError("pad not implemented") + + def training_step(self, batch, batch_idx) -> Dict: + loss_tensors = self._step(batch) + + logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)} + # tokens per batch + tgt_pad_token_id = ( + self.tokenizer.generator.pad_token_id + if isinstance(self.tokenizer, RagTokenizer) + else self.tokenizer.pad_token_id + ) + src_pad_token_id = ( + self.tokenizer.question_encoder.pad_token_id + if isinstance(self.tokenizer, RagTokenizer) + else self.tokenizer.pad_token_id + ) + logs["tpb"] = ( + batch["input_ids"].ne(src_pad_token_id).sum() + batch["decoder_input_ids"].ne(tgt_pad_token_id).sum() + ) + + return {"loss": loss_tensors[0], "log": logs} + + def validation_step(self, batch, batch_idx) -> Dict: + return self._generative_step(batch) + + def validation_epoch_end(self, outputs, prefix="val") -> Dict: + self.step_count += 1 + losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names} + loss = losses["loss"] + gen_metrics = { + k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"] + } + metrics_tensor: torch.FloatTensor = torch.tensor(gen_metrics[self.val_metric]).type_as(loss) + gen_metrics.update({k: v.item() for k, v in losses.items()}) + + # fix for https://github.com/PyTorchLightning/pytorch-lightning/issues/2424 + if dist.is_initialized(): + dist.all_reduce(metrics_tensor, op=dist.ReduceOp.SUM) + metrics_tensor = metrics_tensor / dist.get_world_size() + gen_metrics.update({self.val_metric: metrics_tensor.item()}) + + losses.update(gen_metrics) + metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()} + metrics["step_count"] = self.step_count + self.save_metrics(metrics, prefix) # writes to self.metrics_save_path + preds = flatten_list([x["preds"] for x in outputs]) + return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_{self.val_metric}": metrics_tensor} + + def save_metrics(self, latest_metrics, type_path) -> None: + self.metrics[type_path].append(latest_metrics) + save_json(self.metrics, self.metrics_save_path) + + def calc_generative_metrics(self, preds, target) -> Dict: + return calculate_exact_match(preds, target) + + def _generative_step(self, batch: dict) -> dict: + start_time = time.time() + generated_ids = self.model.generate( + batch["input_ids"], + do_deduplication=False, # rag specific parameter + use_cache=True, + min_length=1, + max_length=self.target_lens["val"], + ) + + gen_time = (time.time() - start_time) / batch["input_ids"].shape[0] + preds: List[str] = self.ids_to_clean_text(generated_ids) + target: List[str] = self.ids_to_clean_text(batch["decoder_input_ids"]) + loss_tensors = self._step(batch) + base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)} + gen_metrics: Dict = self.calc_generative_metrics(preds, target) + + summ_len = np.mean(lmap(len, generated_ids)) + base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **gen_metrics) + return base_metrics + + def test_step(self, batch, batch_idx): + return self._generative_step(batch) + + def test_epoch_end(self, outputs): + return self.validation_epoch_end(outputs, prefix="test") + + def get_dataset(self, type_path) -> Seq2SeqDataset: + n_obs = self.n_obs[type_path] + max_target_length = self.target_lens[type_path] + dataset = Seq2SeqDataset( + self.tokenizer, + type_path=type_path, + n_obs=n_obs, + max_target_length=max_target_length, + **self.dataset_kwargs, + ) + return dataset + + def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader: + dataset = self.get_dataset(type_path) + sampler = None + if self.hparams.sortish_sampler and type_path == "train": + assert self.hparams.gpus <= 1 # TODO: assert earlier + sampler = dataset.make_sortish_sampler(batch_size) + shuffle = False + + dataloader = DataLoader( + dataset, + batch_size=batch_size, + collate_fn=dataset.collate_fn, + shuffle=shuffle, + num_workers=self.num_workers, + sampler=sampler, + ) + return dataloader + + def train_dataloader(self) -> DataLoader: + dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True) + t_total = ( + (len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus))) + // self.hparams.accumulate_grad_batches + * float(self.hparams.max_epochs) + ) + scheduler = get_linear_schedule_with_warmup( + self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total + ) + if max(scheduler.get_last_lr()) > 0: + warnings.warn("All learning rates are 0") + self.lr_scheduler = scheduler + return dataloader + + def val_dataloader(self) -> DataLoader: + return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size) + + def test_dataloader(self) -> DataLoader: + return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size) + + @pl.utilities.rank_zero_only + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + save_path = self.output_dir.joinpath("checkpoint{}".format(self.step_count)) + self.model.config.save_step = self.step_count + self.model.save_pretrained(save_path) + self.tokenizer.save_pretrained(save_path) + + @staticmethod + def add_model_specific_args(parser, root_dir): + BaseTransformer.add_model_specific_args(parser, root_dir) + add_generic_args(parser, root_dir) + parser.add_argument( + "--max_source_length", + default=128, + type=int, + help="The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded.", + ) + parser.add_argument( + "--max_target_length", + default=25, + type=int, + help="The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded.", + ) + parser.add_argument( + "--val_max_target_length", + default=25, + type=int, + help="The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded.", + ) + parser.add_argument( + "--test_max_target_length", + default=25, + type=int, + help="The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded.", + ) + parser.add_argument("--sortish_sampler", action="store_true", default=False) + parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default") + parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.") + parser.add_argument("--n_val", type=int, default=-1, required=False, help="# examples. -1 means use all.") + parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.") + parser.add_argument("--label_smoothing", type=float, default=0.0, required=False) + parser.add_argument( + "--prefix", + type=str, + default=None, + help="Prefix added at the beginning of each text, typically used with T5-based models.", + ) + parser.add_argument( + "--early_stopping_patience", + type=int, + default=-1, + required=False, + help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will effect it.", + ) + parser.add_argument( + "--distributed-port", type=int, default=-1, required=False, help="Port number for distributed training." + ) + parser.add_argument( + "--model_type", + choices=["rag_sequence", "rag_token", "bart", "t5"], + type=str, + help="RAG model type: sequence or token, if none specified, the type is inferred from the model_name_or_path", + ) + return parser + + +def main(args, model=None) -> GenerativeQAModule: + Path(args.output_dir).mkdir(exist_ok=True) + if model is None: + model: GenerativeQAModule = GenerativeQAModule(args) + + dataset = Path(args.data_dir).name + if ( + args.logger_name == "default" + or args.fast_dev_run + or str(args.output_dir).startswith("/tmp") + or str(args.output_dir).startswith("/var") + ): + logger = True # don't pollute wandb logs unnecessarily + elif args.logger_name == "wandb": + from pytorch_lightning.loggers import WandbLogger + + project = os.environ.get("WANDB_PROJECT", dataset) + logger = WandbLogger(name=model.output_dir.name, project=project) + + elif args.logger_name == "wandb_shared": + from pytorch_lightning.loggers import WandbLogger + + logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}") + + es_callback = ( + get_early_stopping_callback(model.val_metric, args.early_stopping_patience) + if args.early_stopping_patience >= 0 + else False + ) + trainer: pl.Trainer = generic_train( + model, + args, + logging_callback=Seq2SeqLoggingCallback(), + checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric), + early_stopping_callback=es_callback, + logger=logger, + ) + pickle_save(model.hparams, model.output_dir / "hparams.pkl") + + if not args.do_predict: + return model + + model.hparams.test_checkpoint = "" + checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True))) + if checkpoints: + model.hparams.test_checkpoint = checkpoints[-1] + trainer.resume_from_checkpoint = checkpoints[-1] # best checkpoint + trainer.logger.log_hyperparams(model.hparams) + + # test() without a model tests using the best checkpoint automatically + trainer.test() + + return model + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser = pl.Trainer.add_argparse_args(parser) + parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd()) + + args = parser.parse_args() + + main(args) diff --git a/examples/rag/finetune.sh b/examples/rag/finetune.sh new file mode 100755 index 000000000..08a6d0a84 --- /dev/null +++ b/examples/rag/finetune.sh @@ -0,0 +1,34 @@ +# Add parent directory to python path to access lightning_base.py +export PYTHONPATH="../":"${PYTHONPATH}" + +# A sample finetuning run, you need to specify data_dir, output_dir and model_name_or_path +# run ./examples/rag/finetune.sh --help to see all the possible options + +python examples/rag/finetune.py \ + --data_dir $DATA_DIR \ + --output_dir $OUTPUT_DIR \ + --model_name_or_path $MODLE_NAME_OR_PATH \ + --model_type rag_sequence \ + --fp16 \ + --gpus 8 \ + --do_train \ + --do_predict \ + --n_val -1 \ + --val_check_interval 0.25 \ + --train_batch_size 8 \ + --eval_batch_size 1 \ + --max_source_length 128 \ + --max_target_length 25 \ + --val_max_target_length 25 \ + --test_max_target_length 25 \ + --label_smoothing 0.1 \ + --dropout 0.1 \ + --attention_dropout 0.1 \ + --weight_decay 0.001 \ + --adam_epsilon 1e-08 \ + --max_grad_norm 0.1 \ + --lr_scheduler polynomial \ + --learning_rate 3e-05 \ + --num_train_epochs 100 \ + --warmup_steps 500 \ + --gradient_accumulation_steps 1 \ No newline at end of file diff --git a/examples/rag/parse_dpr_relevance_data.py b/examples/rag/parse_dpr_relevance_data.py new file mode 100644 index 000000000..4d8a1e5f4 --- /dev/null +++ b/examples/rag/parse_dpr_relevance_data.py @@ -0,0 +1,47 @@ +""" +This script reads DPR retriever training data and parses each datapoint. We save a line per datapoint. +Each line consists of the query followed by a tab-separated list of Wikipedia page titles constituting +positive contexts for a given query. +""" + +import argparse +import json + +from tqdm import tqdm + + +def main(): + parser = argparse.ArgumentParser() + + # Required parameters + parser.add_argument( + "--src_path", + type=str, + default="biencoder-nq-dev.json", + help="Path to raw DPR training data", + ) + parser.add_argument( + "--evaluation_set", + type=str, + help="where to store parsed evaluation_set file", + ) + parser.add_argument( + "--gold_data_path", + type=str, + help="where to store parsed gold_data_path file", + ) + args = parser.parse_args() + + with open(args.src_path, "r") as src_file, open(args.evaluation_set, "w") as eval_file, open( + args.gold_data_path, "w" + ) as gold_file: + dpr_records = json.load(src_file) + for dpr_record in tqdm(dpr_records): + question = dpr_record["question"] + contexts = [context["title"] for context in dpr_record["positive_ctxs"]] + eval_file.write(question + "\n") + gold_file.write("\t".join(contexts) + "\n") + + +if __name__ == "__main__": + main() diff --git a/examples/rag/requirements.txt b/examples/rag/requirements.txt new file mode 100644 index 000000000..9f754bf2b --- /dev/null +++ b/examples/rag/requirements.txt @@ -0,0 +1,4 @@ +faiss-cpu >= 1.6.3 +datasets >= 1.0.1 +psutil >= 5.7.0 +torch >= 1.4.0 \ No newline at end of file diff --git a/examples/rag/test_distributed_retriever.py b/examples/rag/test_distributed_retriever.py new file mode 100644 index 000000000..49fe0e89b --- /dev/null +++ b/examples/rag/test_distributed_retriever.py @@ -0,0 +1,156 @@ +import json +import os +import shutil +import sys +import tempfile +import unittest +from unittest import TestCase +from unittest.mock import patch + +import numpy as np +from datasets import Dataset + +import faiss +from transformers.configuration_bart import BartConfig +from transformers.configuration_dpr import DPRConfig +from transformers.configuration_rag import RagConfig +from transformers.file_utils import is_datasets_available, is_faiss_available, is_psutil_available, is_torch_available +from transformers.tokenization_bart import BartTokenizer +from transformers.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES +from transformers.tokenization_dpr import DPRQuestionEncoderTokenizer +from transformers.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES + + +sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # noqa: E402 # isort:skip + +from examples.rag.distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip + + +def require_distributed_retrieval(test_case): + """ + Decorator marking a test that requires a set of dependencies necessary for pefrorm retrieval with + :class:`~transformers.RagRetriever`. + + These tests are skipped when respective libraries are not installed. + + """ + if not (is_torch_available() and is_datasets_available() and is_faiss_available() and is_psutil_available()): + test_case = unittest.skip("test requires PyTorch, Datasets, Faiss, psutil")(test_case) + return test_case + + +@require_distributed_retrieval +class RagRetrieverTest(TestCase): + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + self.retrieval_vector_size = 8 + + # DPR tok + vocab_tokens = [ + "[UNK]", + "[CLS]", + "[SEP]", + "[PAD]", + "[MASK]", + "want", + "##want", + "##ed", + "wa", + "un", + "runn", + "##ing", + ",", + "low", + "lowest", + ] + dpr_tokenizer_path = os.path.join(self.tmpdirname, "dpr_tokenizer") + os.makedirs(dpr_tokenizer_path, exist_ok=True) + self.vocab_file = os.path.join(dpr_tokenizer_path, DPR_VOCAB_FILES_NAMES["vocab_file"]) + with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer: + vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) + + # BART tok + vocab = [ + "l", + "o", + "w", + "e", + "r", + "s", + "t", + "i", + "d", + "n", + "\u0120", + "\u0120l", + "\u0120n", + "\u0120lo", + "\u0120low", + "er", + "\u0120lowest", + "\u0120newer", + "\u0120wider", + "", + ] + vocab_tokens = dict(zip(vocab, range(len(vocab)))) + merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""] + self.special_tokens_map = {"unk_token": ""} + + bart_tokenizer_path = os.path.join(self.tmpdirname, "bart_tokenizer") + os.makedirs(bart_tokenizer_path, exist_ok=True) + self.vocab_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["vocab_file"]) + self.merges_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["merges_file"]) + with open(self.vocab_file, "w", encoding="utf-8") as fp: + fp.write(json.dumps(vocab_tokens) + "\n") + with open(self.merges_file, "w", encoding="utf-8") as fp: + fp.write("\n".join(merges)) + + def get_dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer: + return DPRQuestionEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer")) + + def get_bart_tokenizer(self) -> BartTokenizer: + return BartTokenizer.from_pretrained(os.path.join(self.tmpdirname, "bart_tokenizer")) + + def tearDown(self): + shutil.rmtree(self.tmpdirname) + + def get_dummy_pytorch_distributed_retriever(self, init_retrieval, port=12345) -> RagPyTorchDistributedRetriever: + dataset = Dataset.from_dict( + { + "id": ["0", "1"], + "text": ["foo", "bar"], + "title": ["Foo", "Bar"], + "embeddings": [np.ones(self.retrieval_vector_size), 2 * np.ones(self.retrieval_vector_size)], + } + ) + dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT) + config = RagConfig( + retrieval_vector_size=self.retrieval_vector_size, + question_encoder=DPRConfig().to_dict(), + generator=BartConfig().to_dict(), + ) + with patch("transformers.retrieval_rag.load_dataset") as mock_load_dataset: + mock_load_dataset.return_value = dataset + retriever = RagPyTorchDistributedRetriever( + config, + question_encoder_tokenizer=self.get_dpr_tokenizer(), + generator_tokenizer=self.get_bart_tokenizer(), + ) + if init_retrieval: + retriever.init_retrieval(port) + return retriever + + def test_pytorch_distributed_retriever_retrieve(self): + n_docs = 1 + retriever = self.get_dummy_pytorch_distributed_retriever(init_retrieval=True) + hidden_states = np.array( + [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32 + ) + retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs) + self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size)) + self.assertEqual(len(doc_dicts), 2) + self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"]) + self.assertEqual(len(doc_dicts[0]["id"]), n_docs) + self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc + self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc + self.assertListEqual(list(doc_ids), [1, 0]) diff --git a/examples/rag/utils.py b/examples/rag/utils.py new file mode 100644 index 000000000..e17fddb73 --- /dev/null +++ b/examples/rag/utils.py @@ -0,0 +1,187 @@ +import linecache +import re +import string +from collections import Counter +from logging import getLogger +from pathlib import Path +from typing import Dict, List + +import torch +from torch.utils.data import Dataset + +from examples.seq2seq.utils import SortishSampler, trim_batch +from transformers import BartTokenizer, RagTokenizer, T5Tokenizer + + +def encode_line(tokenizer, line, max_length, padding_side, pad_to_max_length=True, return_tensors="pt"): + extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) and not line.startswith(" ") else {} + tokenizer.padding_side = padding_side + return tokenizer( + [line], + max_length=max_length, + padding="max_length" if pad_to_max_length else None, + truncation=True, + return_tensors=return_tensors, + add_special_tokens=True, + **extra_kw, + ) + + +class Seq2SeqDataset(Dataset): + def __init__( + self, + tokenizer, + data_dir, + max_source_length, + max_target_length, + type_path="train", + n_obs=None, + src_lang=None, + tgt_lang=None, + prefix="", + ): + super().__init__() + self.src_file = Path(data_dir).joinpath(type_path + ".source") + self.tgt_file = Path(data_dir).joinpath(type_path + ".target") + self.src_lens = self.get_char_lens(self.src_file) + self.max_source_length = max_source_length + self.max_target_length = max_target_length + assert min(self.src_lens) > 0, f"found empty line in {self.src_file}" + self.tokenizer = tokenizer + self.prefix = prefix + if n_obs is not None: + self.src_lens = self.src_lens[:n_obs] + self.src_lang = src_lang + self.tgt_lang = tgt_lang + + def __len__(self): + return len(self.src_lens) + + def __getitem__(self, index) -> Dict[str, torch.Tensor]: + index = index + 1 # linecache starts at 1 + source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n") + tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") + assert source_line, f"empty source line for index {index}" + assert tgt_line, f"empty tgt line for index {index}" + + # Need to add eos token manually for T5 + if isinstance(self.tokenizer, T5Tokenizer): + source_line += self.tokenizer.eos_token + tgt_line += self.tokenizer.eos_token + + # Pad source and target to the right + source_tokenizer = ( + self.tokenizer.question_encoder if isinstance(self.tokenizer, RagTokenizer) else self.tokenizer + ) + target_tokenizer = self.tokenizer.generator if isinstance(self.tokenizer, RagTokenizer) else self.tokenizer + + source_inputs = encode_line(source_tokenizer, source_line, self.max_source_length, "right") + target_inputs = encode_line(target_tokenizer, tgt_line, self.max_target_length, "right") + + source_ids = source_inputs["input_ids"].squeeze() + target_ids = target_inputs["input_ids"].squeeze() + src_mask = source_inputs["attention_mask"].squeeze() + return { + "input_ids": source_ids, + "attention_mask": src_mask, + "decoder_input_ids": target_ids, + } + + @staticmethod + def get_char_lens(data_file): + return [len(x) for x in Path(data_file).open().readlines()] + + def collate_fn(self, batch) -> Dict[str, torch.Tensor]: + input_ids = torch.stack([x["input_ids"] for x in batch]) + masks = torch.stack([x["attention_mask"] for x in batch]) + target_ids = torch.stack([x["decoder_input_ids"] for x in batch]) + tgt_pad_token_id = ( + self.tokenizer.generator.pad_token_id + if isinstance(self.tokenizer, RagTokenizer) + else self.tokenizer.pad_token_id + ) + src_pad_token_id = ( + self.tokenizer.question_encoder.pad_token_id + if isinstance(self.tokenizer, RagTokenizer) + else self.tokenizer.pad_token_id + ) + y = trim_batch(target_ids, tgt_pad_token_id) + source_ids, source_mask = trim_batch(input_ids, src_pad_token_id, attention_mask=masks) + batch = { + "input_ids": source_ids, + "attention_mask": source_mask, + "decoder_input_ids": y, + } + return batch + + def make_sortish_sampler(self, batch_size): + return SortishSampler(self.src_lens, batch_size) + + +logger = getLogger(__name__) + + +def normalize_answer(s): + """Lower text and remove punctuation, articles and extra whitespace.""" + + def remove_articles(text): + return re.sub(r"\b(a|an|the)\b", " ", text) + + def white_space_fix(text): + return " ".join(text.split()) + + def remove_punc(text): + exclude = set(string.punctuation) + return "".join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)))) + + +def f1_score(prediction, ground_truth): + prediction_tokens = normalize_answer(prediction).split() + ground_truth_tokens = normalize_answer(ground_truth).split() + common = Counter(prediction_tokens) & Counter(ground_truth_tokens) + num_same = sum(common.values()) + if num_same == 0: + return 0 + precision = 1.0 * num_same / len(prediction_tokens) + recall = 1.0 * num_same / len(ground_truth_tokens) + f1 = (2 * precision * recall) / (precision + recall) + return f1 + + +def exact_match_score(prediction, ground_truth): + return normalize_answer(prediction) == normalize_answer(ground_truth) + + +def calculate_exact_match(output_lns: List[str], reference_lns: List[str]) -> Dict: + assert len(output_lns) == len(reference_lns) + em = 0 + for hypo, pred in zip(output_lns, reference_lns): + em += exact_match_score(hypo, pred) + if len(output_lns) > 0: + em /= len(output_lns) + return {"em": em} + + +def is_rag_model(model_prefix): + return model_prefix.startswith("rag") + + +def set_extra_model_params(extra_params, hparams, config): + equivalent_param = {p: p for p in extra_params} + # T5 models don't have `dropout` param, they have `dropout_rate` instead + equivalent_param["dropout"] = "dropout_rate" + for p in extra_params: + if getattr(hparams, p, None): + if not hasattr(config, p) and not hasattr(config, equivalent_param[p]): + logger.info("config doesn't have a `{}` attribute".format(p)) + delattr(hparams, p) + continue + set_p = p if hasattr(config, p) else equivalent_param[p] + setattr(config, set_p, getattr(hparams, p)) + delattr(hparams, p) + return hparams, config diff --git a/examples/requirements.txt b/examples/requirements.txt index 45bb30799..9b4433151 100644 --- a/examples/requirements.txt +++ b/examples/requirements.txt @@ -8,7 +8,7 @@ tensorflow_datasets pytorch-lightning==0.8.5 matplotlib git-python==1.0.3 -faiss +faiss-cpu streamlit elasticsearch pandas diff --git a/setup.cfg b/setup.cfg index b7d686bbd..a4f685aaa 100644 --- a/setup.cfg +++ b/setup.cfg @@ -10,7 +10,7 @@ known_third_party = datasets elasticsearch fairseq - faiss + faiss-cpu fastprogress fire fugashi diff --git a/setup.py b/setup.py index 9546a2861..33bdb2ad3 100644 --- a/setup.py +++ b/setup.py @@ -89,7 +89,8 @@ extras["onnxruntime"] = ["onnxruntime>=1.4.0", "onnxruntime-tools>=1.4.2"] extras["serving"] = ["pydantic", "uvicorn", "fastapi", "starlette"] extras["all"] = extras["serving"] + ["tensorflow", "torch"] -extras["testing"] = ["pytest", "pytest-xdist", "timeout-decorator", "psutil", "parameterized"] +extras["retrieval"] = ["faiss-cpu", "datasets"] +extras["testing"] = ["pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil"] + extras["retrieval"] # sphinx-rtd-theme==0.5.0 introduced big changes in the style. extras["docs"] = ["recommonmark", "sphinx", "sphinx-markdown-tables", "sphinx-rtd-theme==0.4.3", "sphinx-copybutton"] extras["quality"] = ["black >= 20.8b1", "isort >= 5", "flake8 >= 3.8.3"] diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index eace84309..9ad13bcd4 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -42,6 +42,7 @@ from .configuration_mmbt import MMBTConfig from .configuration_mobilebert import MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileBertConfig from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig from .configuration_pegasus import PegasusConfig +from .configuration_rag import RagConfig from .configuration_reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig from .configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig @@ -86,6 +87,7 @@ from .file_utils import ( cached_path, is_apex_available, is_datasets_available, + is_faiss_available, is_psutil_available, is_py3nvml_available, is_tf_available, @@ -140,6 +142,9 @@ from .pipelines import ( pipeline, ) +# Retriever +from .retrieval_rag import RagRetriever + # Tokenizers from .tokenization_albert import AlbertTokenizer from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer @@ -172,6 +177,7 @@ from .tokenization_mobilebert import MobileBertTokenizer, MobileBertTokenizerFas from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast from .tokenization_pegasus import PegasusTokenizer from .tokenization_phobert import PhobertTokenizer +from .tokenization_rag import RagTokenizer from .tokenization_reformer import ReformerTokenizer from .tokenization_retribert import RetriBertTokenizer, RetriBertTokenizerFast from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast @@ -416,6 +422,7 @@ if is_torch_available(): load_tf_weights_in_openai_gpt, ) from .modeling_pegasus import PegasusForConditionalGeneration + from .modeling_rag import RagModel, RagSequenceForGeneration, RagTokenForGeneration from .modeling_reformer import ( REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, ReformerAttention, diff --git a/src/transformers/configuration_auto.py b/src/transformers/configuration_auto.py index 39750904f..7bf581217 100644 --- a/src/transformers/configuration_auto.py +++ b/src/transformers/configuration_auto.py @@ -24,6 +24,7 @@ from .configuration_bert_generation import BertGenerationConfig from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig from .configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig +from .configuration_dpr import DPR_PRETRAINED_CONFIG_ARCHIVE_MAP, DPRConfig from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig from .configuration_encoder_decoder import EncoderDecoderConfig from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig @@ -38,6 +39,7 @@ from .configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfi from .configuration_mobilebert import MobileBertConfig from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig from .configuration_pegasus import PegasusConfig +from .configuration_rag import RagConfig from .configuration_reformer import ReformerConfig from .configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig @@ -75,6 +77,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict( FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP, LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, + DPR_PRETRAINED_CONFIG_ARCHIVE_MAP, ] for key, value, in pretrained_map.items() ) @@ -110,7 +113,9 @@ CONFIG_MAPPING = OrderedDict( ("encoder-decoder", EncoderDecoderConfig), ("funnel", FunnelConfig), ("lxmert", LxmertConfig), + ("dpr", DPRConfig), ("layoutlm", LayoutLMConfig), + ("rag", RagConfig), ] ) @@ -145,6 +150,8 @@ MODEL_NAMES_MAPPING = OrderedDict( ("funnel", "Funnel Transformer"), ("lxmert", "LXMERT"), ("layoutlm", "LayoutLM"), + ("dpr", "DPR"), + ("rag", "RAG"), ] ) diff --git a/src/transformers/configuration_dpr.py b/src/transformers/configuration_dpr.py index ea6a6e595..b8efb5986 100644 --- a/src/transformers/configuration_dpr.py +++ b/src/transformers/configuration_dpr.py @@ -14,7 +14,7 @@ # limitations under the License. """ DPR model configuration """ -from .configuration_bert import BertConfig +from .configuration_utils import PretrainedConfig from .utils import logging @@ -27,7 +27,7 @@ DPR_PRETRAINED_CONFIG_ARCHIVE_MAP = { } -class DPRConfig(BertConfig): +class DPRConfig(PretrainedConfig): r""" :class:`~transformers.DPRConfig` is the configuration class to store the configuration of a `DPRModel`. @@ -36,12 +36,73 @@ class DPRConfig(BertConfig): It is used to instantiate the components of the DPR model. Args: - projection_dim (:obj:`int`, optional, defaults to 0): + vocab_size (:obj:`int`, `optional`, defaults to 30522): + Vocabulary size of the DPR model. Defines the different tokens that + can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.BertModel`. + hidden_size (:obj:`int`, `optional`, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (:obj:`int`, `optional`, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (:obj:`int`, `optional`, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (:obj:`int`, `optional`, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. + If string, :obj:`"gelu"`, :obj:`"relu"`, :obj:`"swish"` and :obj:`"gelu_new"` are supported. + hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1): + The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (:obj:`int`, `optional`, defaults to 512): + The maximum sequence length that this model might ever be used with. + Typically set this to something large just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (:obj:`int`, `optional`, defaults to 2): + The vocabulary size of the `token_type_ids` passed into :class:`~transformers.BertModel`. + initializer_range (:obj:`float`, `optional`, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): + The epsilon used by the layer normalization layers. + gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): + If True, use gradient checkpointing to save memory at the expense of slower backward pass. + projection_dim (:obj:`int`, `optional`, defaults to 0): Dimension of the projection for the context and question encoders. If it is set to zero (default), then no projection is done. """ model_type = "dpr" - def __init__(self, projection_dim: int = 0, **kwargs): # projection of the encoders, 0 for no projection - super().__init__(**kwargs) + def __init__( + self, + vocab_size=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + pad_token_id=0, + gradient_checkpointing=False, + projection_dim: int = 0, + **kwargs + ): + super().__init__(pad_token_id=pad_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.gradient_checkpointing = gradient_checkpointing self.projection_dim = projection_dim diff --git a/src/transformers/configuration_rag.py b/src/transformers/configuration_rag.py new file mode 100644 index 000000000..4b91ae893 --- /dev/null +++ b/src/transformers/configuration_rag.py @@ -0,0 +1,176 @@ +# coding=utf-8 +# Copyright 2020, The RAG Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" RAG model configuration """ + +import copy + +from .configuration_utils import PretrainedConfig +from .file_utils import add_start_docstrings + + +RAG_CONFIG_DOC = r""" + :class:`~transformers.RagConfig` stores the configuration of a `RagModel`. + Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used + to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig` + for more information. + + Args: + title_sep (:obj:`str`, `optional`, defaults to ``" / "``): + Separator inserted between the title and the text of the retrieved document when calling :class:`~transformers.RagRetriever`. + doc_sep (:obj:`str`, `optional`, defaults to ``" // "``): + Separator inserted between the the text of the retrieved document and the original input when calliang :class:`~transformers.RagRetriever`. + n_docs (:obj:`int`, `optional`, defaults to 5): + Number of documents to retrieve. + max_combined_length (:obj:`int`, `optional`, defaults to 300): + Max length of contextualized input returned by :meth:`~transformers.RagRetriever.__call__`. + retrieval_vector_size (:obj:`int`, `optional`, defaults to 768): + Dimensionality of the document embeddings indexed by :class:`~transformers.RagRetriever`. + retrieval_batch_size (:obj:`int`, `optional`, defaults to 8): + Retrieval batch size, defined as the number of queries issues concurrently to the faiss index excapsulated :class:`~transformers.RagRetriever`. + dataset (:obj:`str`, `optional`, defaults to :obj:`"wiki_dpr"`): + A datatset identifier of the indexed dataset on HuggingFace AWS bucket (list all available datasets and ids using :obj:`datasets.list_datasets()`). + dataset_split (:obj:`str`, `optional`, defaults to :obj:`train`) + Which split of the ``dataset`` to load. + index_name (:obj:`str`, `optional`, defaults to :obj:`compressed`) + The index_name of the index associated with the :obj:`dataset`. One can choose between :obj:`legacy`, :obj:`exact` and :obj:`compressed`. + index_path (:obj:`str`, `optional`) + The path to the serialized faiss index on disk. + passages_path: (:obj:`str`, `optional`): + A path to text passages compatible with the faiss index. Required if using :class:`~transformers.retrieval_rag.LegacyIndex` + use_dummy_dataset (:obj:`bool`, `optional`, defaults to ``False``) + Whether to load a "dummy" variant of the dataset specified by :obj:`dataset`. + label_smoothing (:obj:`float`, `optional`, defaults to 0.0): + Only relevant if ``return_loss`` is set to :obj:`True`. Controls the ``epsilon`` parameter value for label smoothing in the loss calculation. + If set to ``0.0``, no label smoothing is performed. + do_marginalize (:obj:`bool`, `optional`, defaults to :obj:`False`): + If :obj:`True`, the logits are marginalized over all documents + by making use of ``torch.nn.functional.log_softmax``. + reduce_loss (:obj:`bool`, `optional`, defaults to :obj:`False`): + If :obj:`True`, the NLL loss is reduced using the ``torch.Tensor.sum`` operation. + do_deduplication (:obj:`bool`, `optional`, defaults to :obj:`True`): + Controls whether we want to deduplicate the generations from different context documents for a given input. + Has to be set to :obj:`False` if used while training with distributed backend. + exclude_bos_score (:obj:`bool`, `optional`, defaults to :obj:`False`): + If :obj:`True`, the score of the BOS token is disregarded when computing + the loss. + output_retrieved(:obj:`bool`, `optional`, defaults to :obj:`False`): + If set to ``True``, :obj:`retrieved_doc_embeds`, :obj:`retrieved_doc_ids`, :obj:`context_input_ids` and :obj:`context_attention_mask` are returned. See returned tensors for more detail. +""" + + +@add_start_docstrings(RAG_CONFIG_DOC) +class RagConfig(PretrainedConfig): + model_type = "rag" + + def __init__( + self, + vocab_size=None, + is_encoder_decoder=True, + prefix=None, + bos_token_id=None, + pad_token_id=None, + eos_token_id=None, + decoder_start_token_id=None, + title_sep=" / ", + doc_sep=" // ", + n_docs=5, + max_combined_length=300, + retrieval_vector_size=768, + retrieval_batch_size=8, + dataset="wiki_dpr", + dataset_split="train", + index_name="compressed", + index_path=None, + passages_path=None, + use_dummy_dataset=False, + reduce_loss=False, + label_smoothing=0.0, + do_deduplication=True, + exclude_bos_score=False, + do_marginalize=False, + output_retrieved=False, + **kwargs + ): + super().__init__( + bos_token_id=bos_token_id, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + decoder_start_token_id=decoder_start_token_id, + is_encoder_decoder=is_encoder_decoder, + prefix=prefix, + vocab_size=vocab_size, + **kwargs, + ) + assert ( + "question_encoder" in kwargs and "generator" in kwargs + ), "Config has to be initialized with question_encoder and generator config" + question_encoder_config = kwargs.pop("question_encoder") + question_encoder_model_type = question_encoder_config.pop("model_type") + decoder_config = kwargs.pop("generator") + decoder_model_type = decoder_config.pop("model_type") + + from .configuration_auto import AutoConfig + + self.question_encoder = AutoConfig.for_model(question_encoder_model_type, **question_encoder_config) + self.generator = AutoConfig.for_model(decoder_model_type, **decoder_config) + + self.reduce_loss = reduce_loss + self.label_smoothing = label_smoothing + self.exclude_bos_score = exclude_bos_score + self.do_marginalize = do_marginalize + + self.title_sep = title_sep + self.doc_sep = doc_sep + self.n_docs = n_docs + self.max_combined_length = max_combined_length + + self.dataset = dataset + self.dataset_split = dataset_split + self.index_name = index_name + + self.retrieval_vector_size = retrieval_vector_size + self.retrieval_batch_size = retrieval_batch_size + self.passages_path = passages_path + self.index_path = index_path + self.use_dummy_dataset = use_dummy_dataset + + self.output_retrieved = output_retrieved + + self.do_deduplication = do_deduplication + + @classmethod + def from_question_encoder_generator_configs( + cls, question_encoder_config: PretrainedConfig, generator_config: PretrainedConfig, **kwargs + ) -> PretrainedConfig: + r""" + Instantiate a :class:`~transformers.EncoderDecoderConfig` (or a derived class) from a pre-trained encoder model configuration and decoder model configuration. + + Returns: + :class:`EncoderDecoderConfig`: An instance of a configuration object + """ + return cls(question_encoder=question_encoder_config.to_dict(), generator=generator_config.to_dict(), **kwargs) + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default :meth:`~transformers.PretrainedConfig.to_dict`. + + Returns: + :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = copy.deepcopy(self.__dict__) + output["question_encoder"] = self.question_encoder.to_dict() + output["generator"] = self.generator.to_dict() + output["model_type"] = self.__class__.model_type + return output diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index 09997561e..4daac636d 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -69,6 +69,7 @@ try: import datasets # noqa: F401 _datasets_available = True + logger.debug(f"Succesfully imported datasets version {datasets.__version__}") except ImportError: _datasets_available = False @@ -119,6 +120,16 @@ try: except ImportError: _has_apex = False + +try: + import faiss # noqa: F401 + + _faiss_available = True + logger.debug(f"Succesfully imported faiss version {faiss.__version__}") +except ImportError: + _faiss_available = False + + default_cache_path = os.path.join(torch_cache_home, "transformers") @@ -175,6 +186,10 @@ def is_apex_available(): return _has_apex +def is_faiss_available(): + return _faiss_available + + def add_start_docstrings(*docstr): def docstring_decorator(fn): fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "") diff --git a/src/transformers/modeling_auto.py b/src/transformers/modeling_auto.py index bba7b2742..04c66285c 100644 --- a/src/transformers/modeling_auto.py +++ b/src/transformers/modeling_auto.py @@ -27,6 +27,7 @@ from .configuration_auto import ( CamembertConfig, CTRLConfig, DistilBertConfig, + DPRConfig, ElectraConfig, EncoderDecoderConfig, FlaubertConfig, @@ -97,6 +98,7 @@ from .modeling_distilbert import ( DistilBertForTokenClassification, DistilBertModel, ) +from .modeling_dpr import DPRQuestionEncoder from .modeling_electra import ( ElectraForMaskedLM, ElectraForMultipleChoice, @@ -148,6 +150,11 @@ from .modeling_mobilebert import ( ) from .modeling_openai import OpenAIGPTLMHeadModel, OpenAIGPTModel from .modeling_pegasus import PegasusForConditionalGeneration +from .modeling_rag import ( # noqa: F401 - need to import all RagModels to be in globals() function + RagModel, + RagSequenceForGeneration, + RagTokenForGeneration, +) from .modeling_reformer import ( ReformerForMaskedLM, ReformerForQuestionAnswering, @@ -224,6 +231,7 @@ MODEL_MAPPING = OrderedDict( (FunnelConfig, FunnelModel), (LxmertConfig, LxmertModel), (BertGenerationConfig, BertGenerationEncoder), + (DPRConfig, DPRQuestionEncoder), ] ) @@ -412,7 +420,6 @@ MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict( ] ) - AUTO_MODEL_PRETRAINED_DOCSTRING = r""" The model class to instantiate is selected based on the :obj:`model_type` property of the config object diff --git a/src/transformers/modeling_dpr.py b/src/transformers/modeling_dpr.py index 1ba940bd5..d7e14d4ac 100644 --- a/src/transformers/modeling_dpr.py +++ b/src/transformers/modeling_dpr.py @@ -272,6 +272,7 @@ class DPRPretrainedContextEncoder(PreTrainedModel): config_class = DPRConfig load_tf_weights = None base_model_prefix = "ctx_encoder" + authorized_missing_keys = [r"position_ids"] def init_weights(self): self.ctx_encoder.init_weights() @@ -285,6 +286,7 @@ class DPRPretrainedQuestionEncoder(PreTrainedModel): config_class = DPRConfig load_tf_weights = None base_model_prefix = "question_encoder" + authorized_missing_keys = [r"position_ids"] def init_weights(self): self.question_encoder.init_weights() @@ -298,6 +300,7 @@ class DPRPretrainedReader(PreTrainedModel): config_class = DPRConfig load_tf_weights = None base_model_prefix = "span_predictor" + authorized_missing_keys = [r"position_ids"] def init_weights(self): self.span_predictor.encoder.init_weights() diff --git a/src/transformers/modeling_rag.py b/src/transformers/modeling_rag.py new file mode 100644 index 000000000..da89004db --- /dev/null +++ b/src/transformers/modeling_rag.py @@ -0,0 +1,1394 @@ +# coding=utf-8 +# Copyright 2020, The RAG Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""RAG model implementation.""" + +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import torch + +from .configuration_rag import RagConfig +from .configuration_utils import PretrainedConfig +from .file_utils import add_start_docstrings_to_callable, replace_return_docstrings +from .modeling_outputs import ModelOutput +from .modeling_utils import PreTrainedModel +from .retrieval_rag import RagRetriever +from .utils import logging + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "RagConfig" + + +@dataclass +class RetrievAugLMMarginOutput(ModelOutput): + """ + Base class for retriever augmented marginalized models outputs. + + Args: + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided): + Languaged modeling loss. + logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head. + The score is possibly marginalized over all documents for each vocabulary token. + doc_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.n_docs)`): + Score between each retrieved document embeddigs + (see :obj:`retrieved_doc_embeds`) and :obj:`question_encoder_last_hidden_state`. + past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape + :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) + of the decoder that can be used (see ``past_key_values`` input) to + speed up sequential decoding. + retrieved_doc_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.n_docs, hidden_size)`, `optional`, returned when `output_retrieved=True`): + Embedded documents retrieved by the retriever. + Is used with ``question_encoder_last_hidden_state`` to compute + the ``doc_scores``. + retrieved_doc_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, config.n_docs)`, `optional`, returned when `output_retrieved=True`): + The indexes of the embedded documents retrieved by the retriever. + context_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * config.n_docs, config.max_combined_length)`, `optional`, returned when `output_retrieved=True`): + Input ids post-processed from the retrieved documents + and the question encoder input_ids by the retriever. + context_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size * config.n_docs, config.max_combined_length)`, `optional`, returned when `output_retrieved=True`): + Attention mask post-processed from the retrieved documents + and the question encoder input_ids by the retriever. + question_encoder_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer + of the question encoder pooled output of the model. + question_enc_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the question encoder at the output of each layer plus the initial embedding outputs. + question_enc_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights of the question encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + generator_enc_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the generator encoder of the model. + generator_enc_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the generator encoder at the output of each layer plus the initial embedding outputs. + generator_enc_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights of the generator encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + generator_dec_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the generator decoder at the output of each layer plus the initial embedding outputs. + generator_dec_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights of the generator decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + doc_scores: torch.FloatTensor = None + past_key_values: Optional[List[torch.FloatTensor]] = None + retrieved_doc_embeds: Optional[torch.FloatTensor] = None + retrieved_doc_ids: Optional[torch.LongTensor] = None + context_input_ids: Optional[torch.LongTensor] = None + context_attention_mask: Optional[torch.LongTensor] = None + question_encoder_last_hidden_state: Optional[torch.FloatTensor] = None + question_enc_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + question_enc_attentions: Optional[Tuple[torch.FloatTensor]] = None + generator_enc_last_hidden_state: Optional[torch.FloatTensor] = None + generator_enc_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + generator_enc_attentions: Optional[Tuple[torch.FloatTensor]] = None + generator_dec_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + generator_dec_attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class RetrievAugLMOutput(ModelOutput): + """ + Args: + logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head. + The score is possibly marginalized over all documents for each vocabulary token. + doc_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.n_docs)`): + Score between each retrieved document embeddigs (see :obj:`retrieved_doc_embeds`) and :obj:`question_encoder_last_hidden_state`. + past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, + with each tensor of shape + :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`). + Contains pre-computed hidden-states (key and values in the attention blocks) + of the decoder that can be used (see ``past_key_values`` input) to + speed up sequential decoding. + retrieved_doc_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.n_docs, hidden_size)`, `optional`, returned when `output_retrieved=True`): + Embedded documents retrieved by the retriever. + Is used with ``question_encoder_last_hidden_state`` to compute the ``doc_scores``. + retrieved_doc_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, config.n_docs)`, `optional`, returned when `output_retrieved=True`): + The indexes of the embedded documents retrieved by the retriever. + context_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * config.n_docs, config.max_combined_length)`, `optional`, returned when `output_retrieved=True`): + Input ids post-processed from the retrieved documents + and the question encoder input_ids by the retriever. + context_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size * config.n_docs, config.max_combined_length)`, `optional`, returned when `output_retrieved=True`): + Attention mask post-processed from the retrieved + documents and the question encoder input_ids by the retriever. + question_encoder_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer + of the question encoder pooled output of the model. + question_enc_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the question encoder at the output of each + layer plus the initial embedding outputs. + question_enc_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights of the question encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + generator_enc_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the generator encoder of the model. + generator_enc_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the generator encoder at the output + of each layer plus the initial embedding outputs. + generator_enc_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights of the generator encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + generator_dec_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the generator decoder at the output of each layer plus the initial embedding outputs. + generator_dec_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights of the generator decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + logits: torch.FloatTensor = None + doc_scores: torch.FloatTensor = None + past_key_values: Optional[List[torch.FloatTensor]] = None + retrieved_doc_embeds: Optional[torch.FloatTensor] = None + retrieved_doc_ids: Optional[torch.LongTensor] = None + context_input_ids: Optional[torch.LongTensor] = None + context_attention_mask: Optional[torch.LongTensor] = None + question_encoder_last_hidden_state: Optional[torch.FloatTensor] = None + question_enc_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + question_enc_attentions: Optional[Tuple[torch.FloatTensor]] = None + generator_enc_last_hidden_state: Optional[torch.FloatTensor] = None + generator_enc_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + generator_enc_attentions: Optional[Tuple[torch.FloatTensor]] = None + generator_dec_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + generator_dec_attentions: Optional[Tuple[torch.FloatTensor]] = None + + +class RagPreTrainedModel(PreTrainedModel): + r""" + RAG models were released with the paper `Retrieval-Augmented Generation for + Knowledge-Intensive NLP Tasks `_ by Patrick Lewis, Ethan Perez, Aleksandra Piktus et al. + + RAG is a retriever augmented model and encapsulate three components: a question encoder, a dataset retriever and a generator, the encoder and generator are trainable while the retriever is just an indexed dataset. + + """ + config_class = RagConfig + base_model_prefix = "rag" + authorized_missing_keys = [r"position_ids"] + + @classmethod + def from_pretrained_question_encoder_generator( + cls, + question_encoder_pretrained_model_name_or_path: str = None, + generator_pretrained_model_name_or_path: str = None, + retriever: RagRetriever = None, + *model_args, + **kwargs + ) -> PreTrainedModel: + r"""Instantiates an question_encoder and a generator from one or two base classes of the library from pre-trained model checkpoints. + + + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). + To train the model, you need to first set it back in training mode with `model.train()`. + + Params: + question_encoder_pretrained_model_name_or_path (:obj: `str`, `optional`, defaults to `None`): + information necessary to initiate the question_encoder. Either: + + - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``. + - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``. + - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/question_encoder``. + - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. + + generator_pretrained_model_name_or_path (:obj: `str`, `optional`, defaults to `None`): + information necessary to initiate the generator. Either: + + - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``. + - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``. + - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/generator``. + - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. + + model_args: (`optional`) Sequence of positional arguments: + All remaning positional arguments will be passed to the underlying model's ``__init__`` method + + retriever: (`optional`, ``RagRetriever``) An instance of a :class:`~transformers.RagRetriever` to use as a retriever. + + kwargs: (`optional`) Remaining dictionary of keyword arguments. + Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attentions=True``). + - To update the question_encoder configuration, use the prefix `question_encoder_` for each configuration parameter + - To update the generator configuration, use the prefix `generator_` for each configuration parameter + - To update the parent model configuration, do not use a prefix for each configuration parameter + Behave differently depending on whether a :obj:`config` is provided or automatically loaded. + + Example:: + + >>> from transformers import RagModel + >>> # initialize a RAG from two pretrained models. + >>> model = RagModel.from_question_encoder_generator_pretrained('facebook/dpr-question_encoder-single-nq-base', 't5-small') + >>> # saving model after fine-tuning + >>> model.save_pretrained("./rag") + >>> # load fine-tuned model + >>> model = RagModel.from_pretrained("./rag") + + """ + + kwargs_question_encoder = { + argument[len("question_question_encoder_") :]: value + for argument, value in kwargs.items() + if argument.startswith("question_encoder_") + } + + kwargs_generator = { + argument[len("generator_") :]: value + for argument, value in kwargs.items() + if argument.startswith("generator_") + } + + # remove question_encoder, generator kwargs from kwargs + for key in kwargs_question_encoder.keys(): + del kwargs["question_encoder_" + key] + for key in kwargs_generator.keys(): + del kwargs["generator_" + key] + + # Load and initialize the question_encoder and generator + # The distinction between question_encoder and generator at the model level is made + # by the value of the flag `is_generator` that we need to set correctly. + question_encoder = kwargs_question_encoder.pop("model", None) + if question_encoder is None: + assert ( + question_encoder_pretrained_model_name_or_path is not None + ), "If `model` is not defined as an argument, a `question_encoder_pretrained_model_name_or_path` has to be defined" + from .modeling_auto import AutoModel + + if "config" not in kwargs_question_encoder: + from .configuration_auto import AutoConfig + + question_encoder_config = AutoConfig.from_pretrained(question_encoder_pretrained_model_name_or_path) + kwargs_question_encoder["config"] = question_encoder_config + + question_encoder = AutoModel.from_pretrained( + question_encoder_pretrained_model_name_or_path, *model_args, **kwargs_question_encoder + ) + + generator = kwargs_generator.pop("model", None) + if generator is None: + assert ( + generator_pretrained_model_name_or_path is not None + ), "If `generator_model` is not defined as an argument, a `generator_pretrained_model_name_or_path` has to be defined" + from .modeling_auto import AutoModelForSeq2SeqLM + + if "config" not in kwargs_generator: + from .configuration_auto import AutoConfig + + generator_config = AutoConfig.from_pretrained(generator_pretrained_model_name_or_path) + kwargs_generator["config"] = generator_config + + generator = AutoModelForSeq2SeqLM.from_pretrained( + generator_pretrained_model_name_or_path, **kwargs_generator + ) + + # instantiate config with corresponding kwargs + config = kwargs.get("config", None) + if config is None: + config = RagConfig.from_question_encoder_generator_configs( + question_encoder.config, generator.config, **kwargs + ) + + return cls(question_encoder=question_encoder, generator=generator, config=config, retriever=retriever) + + +RAG_START_DOCSTRING = r""" + RAG is a seq2seq model which encapsulates two core components: a question encoder and a generator. + During a forward pass, we encode the input with the question encoder and pass it + to the retriever to extract relevant context documents. The documents are then prepended to the input. + Such contextualized inputs is passed to the generator. + + The question encoder can be any `autoencoding` model, preferably :obj:`~transformers.DPRQuestionEncoder`, and the generator can be any `seq2seq` model, preferably :obj:`~transformers.BartForConditionalGeneration`. + + The model can be initialized with a :obj:`~transformers.RagRetriever` for end-to-end generation or used in combination with the outputs of a retriever in multiple steps - see examples for more details. + The model is compatible any `autoencoding` model as the ``question_encoder`` and any `seq2seq` model with language model head as the ``generator``. + The model has been tested with :class:`~transformers.DPRQuestionEncoder` as the ``question_encoder`` and :class:`~transformers.BartForConditionalGeneration` or :class:`~transformers.T5ForConditionalGeneration` as the ``generator``. + + This model is a PyTorch `torch.nn.Module `_ sub-class. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general + usage and behavior. + + Args: + config (:class:`~transformers.RagConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the configuration. + Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. + question_encoder (:class:`transformers.PreTrainedModel`): + An encoder model compatible with the faiss index encapsulated by the ``retriever``. + generator (:class:`transformers.PreTrainedModel`): + A seq2seq model used as the generator in the RAG architecture. + retriever (:class:`~transformers.RagRetriever`): + A retriever class encapsulating a faiss index queried to obtain context documents for current inputs. +""" + + +RAG_FORWARD_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + :class:`~transformers.RagConfig`, used to initialize the model, specifies which generator to use, it also specifies a compatible + generator tokenizer. Use that tokenizer class to obtain the indices. + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Mask to avoid performing attention on padding token indices in input_ids. + Mask values selected in ``[0, 1]``: + ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. + encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`) + Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: :obj:`attentions`) + `last_hidden_state` of shape :obj:`(batch_size, n_docs * sequence_length, hidden_size)` is a sequence of hidden-states at the output of the last layer of the encoder. + `doc_scores` of shape :obj:`(batch_size, n_docs)` store retrieval scores of documents retrieved for each input in the batch. + Used by the (:class:`~transformers.RagTokenForGeneration`) model during decoding. + decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`): + Provide for generation tasks. `None` by default, constuct as per instructions for the generator model you're using with your RAG instance. + Provide for generation tasks. `None` by default, constuct as per instructions for the generator model you're using with your RAG instance. + decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): + Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))`): + Tuple consists of two elements: :obj:`encoder_outputs` of the RAG model (see :obj:`encoder_outputs`) and :obj:`past_key_values` of the underlying generator. + Can be used to speed up decoding. :obj:`past_key_values` are used in the (:class:`~transformers.RagTokenForGeneration`) + model during decoding. + doc_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.n_docs)`): + Score between each retrieved document embeddigs (see :obj:`retrieved_doc_embeds`) and :obj:`question_encoder_last_hidden_state`. + If the model has is not initialized with a ``retriever`` :obj:`doc_scores` has to be provided to the forward pass. :obj:`doc_scores` can be computed via :obj:`question_encoder_last_hidden_state` and :obj:`retrieved_doc_embeds`, see examples for more information. + context_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * config.n_docs, config.max_combined_length)`, `optional`, returned when `output_retrieved=True`): + Input ids post-processed from the retrieved documents and the question encoder input_ids by the retriever. + If the model has is not initialized with a ``retriever`` :obj:`context_input_ids` has to be provided to the forward pass. :obj:`context_input_ids` are returned by :meth:`~transformers.RagRetriever.__call__` + context_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size * config.n_docs, config.max_combined_length)`, `optional`, returned when `output_retrieved=True`): + Attention mask post-processed from the retrieved documents and the question encoder input_ids by the retriever. + If the model has is not initialized with a ``retriever`` :obj:`context_attention_mask` has to be provided to the forward pass. :obj:`context_attention_mask` are returned by :meth:`~transformers.RagRetriever.__call__` + use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): + If `use_cache` is True, ``past_key_values`` are returned and can be used to speed up decoding (see + ``past_key_values``). + output_attentions (:obj:`bool`, `optional`): + If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail. + output_retrieved(:obj:`bool`, `optional`): + If set to ``True``, :obj:`retrieved_doc_embeds`, :obj:`retrieved_doc_ids`, :obj:`context_input_ids` and :obj:`context_attention_mask` are returned. See returned tensors for more detail. +""" + + +@add_start_docstrings_to_callable(RAG_START_DOCSTRING) +class RagModel(RagPreTrainedModel): + def __init__( + self, + config: Optional[PretrainedConfig] = None, + question_encoder: Optional[PreTrainedModel] = None, + generator: Optional[PreTrainedModel] = None, + retriever: Optional = None, # or maybe just use a `set_retriever(...)` method + **kwargs, + ): + assert config is not None or ( + question_encoder is not None and generator is not None + ), "Either a configuration or an question_encoder and a generator has to be provided." + + if config is None: + config = RagConfig.from_question_encoder_generator_configs( + question_encoder.config, generator.config, **kwargs + ) + else: + assert isinstance(config, self.config_class), "config: {} has to be of type {}".format( + config, self.config_class + ) + super().__init__(config) + if question_encoder is None: + from .modeling_auto import AutoModel + + question_encoder = AutoModel.from_config(config.question_encoder) + + if generator is None: + from .modeling_auto import AutoModelForSeq2SeqLM + + generator = AutoModelForSeq2SeqLM.from_config(config.generator) + + self.retriever = retriever + if self.retriever is not None: + assert isinstance( + retriever, RagRetriever + ), f"`self.retriever` is of type {type(self.retriever)}, but should be of type `RagRetriever`" + self.retriever = retriever + + self.question_encoder = question_encoder + self.generator = generator + + @add_start_docstrings_to_callable(RAG_FORWARD_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=RetrievAugLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_outputs=None, + decoder_input_ids=None, + decoder_attention_mask=None, + past_key_values=None, + doc_scores=None, + context_input_ids=None, + context_attention_mask=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + output_retrieved=None, + ): + r""" + Returns: + + Example:: + + >>> from transformers import RagTokenizer, RagRetriever, RagModel + >>> import torch + + >>> tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base") + >>> retriever = RagRetriever.from_pretrained("facebook/rag-token-base", index_name="exact", use_dummy_dataset=True) + >>> # initialize with RagRetriever to do everything in one forward call + >>> model = RagModel.from_pretrained("facebook/rag-token-base", retriever=retriever) + + >>> input_dict = tokenizer.prepare_seq2seq_batch("How many people live in Paris?", "In Paris, there are 10 million people.", return_tensors="pt") + >>> input_ids = input_dict["input_ids"] + >>> outputs = model(input_ids=input_ids, labels=input_dict["labels"]) + + """ + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_retrieved = output_retrieved if output_retrieved is not None else self.config.output_retrieved + + # whether retriever has to be used + has_to_retrieve = ( + self.retriever is not None + and (context_input_ids is None or context_attention_mask is None or doc_scores is None) + and encoder_outputs is None + ) + # encoder_outputs are pre-computed during RAG-token generation + if encoder_outputs is None: + + if has_to_retrieve: + question_enc_outputs = self.question_encoder( + input_ids, attention_mask=attention_mask, return_dict=True + ) + question_encoder_last_hidden_state = question_enc_outputs[0] # hidden states of question encoder + + retriever_outputs = self.retriever( + input_ids, + question_encoder_last_hidden_state.cpu().detach().to(torch.float32).numpy(), + prefix=self.generator.config.prefix, + n_docs=self.config.n_docs, + return_tensors="pt", + ) + context_input_ids, context_attention_mask, retrieved_doc_embeds, retrieved_doc_ids = ( + retriever_outputs["context_input_ids"], + retriever_outputs["context_attention_mask"], + retriever_outputs["retrieved_doc_embeds"], + retriever_outputs["doc_ids"], + ) + + # set to correct device + retrieved_doc_embeds = retrieved_doc_embeds.to(question_encoder_last_hidden_state) + context_input_ids = context_input_ids.to(input_ids) + context_attention_mask = context_attention_mask.to(input_ids) + + # compute doc_scores + doc_scores = torch.bmm( + question_encoder_last_hidden_state.unsqueeze(1), retrieved_doc_embeds.transpose(1, 2) + ).squeeze(1) + else: + assert ( + context_input_ids is not None + ), "Make sure that `context_input_ids` are passed, if no `retriever` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function." + assert ( + context_attention_mask is not None + ), "Make sure that `context_attention_mask` are passed, if no `retriever` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function." + assert ( + doc_scores is not None + ), "Make sure that `doc_scores` are passed, if no `retriever` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function." + + assert ( + doc_scores is not None + ), "Make sure that `doc_scores` are passed when passing `encoder_outputs` to the forward function." + + # Decoder input without context documents + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.repeat_interleave(self.config.n_docs, dim=0) + + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.repeat_interleave(self.config.n_docs, dim=0) + + gen_outputs = self.generator( + input_ids=context_input_ids, + attention_mask=context_attention_mask, + encoder_outputs=encoder_outputs, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + return_dict=True, + ) + + if not has_to_retrieve: + question_encoder_last_hidden_state = None + question_enc_hidden_states = None + question_enc_attentions = None + retrieved_doc_embeds = None + retrieved_doc_ids = None + else: + question_enc_hidden_states = question_enc_outputs.hidden_states + question_enc_attentions = question_enc_outputs.attentions + + if not has_to_retrieve or not output_retrieved: + # don't output retrieved docs + context_input_ids = (None,) + context_attention_mask = None + retrieved_doc_embeds = None + retrieved_doc_ids = None + + return RetrievAugLMOutput( + logits=gen_outputs.logits, + doc_scores=doc_scores, + past_key_values=gen_outputs.past_key_values, + context_input_ids=context_input_ids, + context_attention_mask=context_attention_mask, + retrieved_doc_embeds=retrieved_doc_embeds, + retrieved_doc_ids=retrieved_doc_ids, + question_encoder_last_hidden_state=question_encoder_last_hidden_state, + question_enc_hidden_states=question_enc_hidden_states, + question_enc_attentions=question_enc_attentions, + generator_enc_last_hidden_state=gen_outputs.encoder_last_hidden_state, + generator_enc_hidden_states=gen_outputs.encoder_hidden_states, + generator_enc_attentions=gen_outputs.encoder_attentions, + generator_dec_hidden_states=gen_outputs.decoder_hidden_states, + generator_dec_attentions=gen_outputs.decoder_attentions, + ) + + +@add_start_docstrings_to_callable( + """A RAG-sequence model impementation. It performs RAG-sequence specific marginalization in the forward pass. + """, + RAG_START_DOCSTRING, +) +class RagSequenceForGeneration(RagPreTrainedModel): + def __init__( + self, + config: Optional[PretrainedConfig] = None, + question_encoder: Optional[PreTrainedModel] = None, + generator: Optional[PreTrainedModel] = None, + retriever: Optional = None, + **kwargs, + ): + assert config is not None or ( + question_encoder is not None and generator is not None + ), "Either a configuration or an encoder and a generator has to be provided." + + if config is None: + config = RagConfig.from_encoder_generator_configs(question_encoder.config, generator.config, **kwargs) + super().__init__(config) + + # instantiate model + self.rag = RagModel(config=config, question_encoder=question_encoder, generator=generator, retriever=retriever) + + def set_retriever(self, retriever: RagRetriever): + self.rag.retriever = retriever + + @add_start_docstrings_to_callable(RAG_FORWARD_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=RetrievAugLMMarginOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_outputs=None, + decoder_input_ids=None, + decoder_attention_mask=None, + past_key_values=None, + context_input_ids=None, + context_attention_mask=None, + doc_scores=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + output_retrieved=None, + exclude_bos_score=None, + reduce_loss=None, + labels=None, + **kwargs # needs kwargs for generation + ): + r""" + exclude_bos_score (:obj:`bool`, `optional`): + Only relevant if ``labels`` is passed. + If :obj:`True`, the score of the BOS token is disregarded when computing + the loss. + reduce_loss (:obj:`bool`, `optional`): + Only relevant if ``labels`` is passed. + If :obj:`True`, the NLL loss is reduced using the ``torch.Tensor.sum`` operation. + kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`): + Legacy dictionary, which is required so that model can use `generate()` function. + + Returns: + + Example:: + + >>> from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration + >>> import torch + + >>> tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq") + >>> retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True) + >>> # initialize with RagRetriever to do everything in one forward call + >>> model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever) + + >>> input_dict = tokenizer.prepare_seq2seq_batch("How many people live in Paris?", "In Paris, there are 10 million people.", return_tensors="pt") + >>> input_ids = input_dict["input_ids"] + >>> outputs = model(input_ids=input_ids, labels=input_dict["labels"]) + + >>> # or use retriever seperately + >>> model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", use_dummy_dataset=True) + >>> # 1. Encode + >>> question_hidden_states = model.question_encoder(input_ids)[0] + >>> # 2. Retrieve + >>> docs_dict = retriever(input_ids.numpy(), question_hidden_states.detach().numpy(), return_tensors="pt") + >>> doc_scores = torch.bmm(question_hidden_states.unsqueeze(1), docs_dict["retrieved_doc_embeds"].float().transpose(1, 2)).squeeze(1) + >>> # 3. Forward to generator + >>> outputs = model(context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores, decoder_input_ids=input_dict["labels"]) + + >>> # or directly generate + >>> generated = model.generate(input_ids=input_dict["input_ids"]) + >>> generated_string = tokenizer.batch_decode(generated, skip_special_tokens=True) + """ + exclude_bos_score = exclude_bos_score if exclude_bos_score is not None else self.config.exclude_bos_score + reduce_loss = reduce_loss if reduce_loss is not None else self.config.reduce_loss + + if labels is not None: + if decoder_input_ids is None: + decoder_input_ids = labels + use_cache = False + + outputs = self.rag( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_outputs=encoder_outputs, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + context_input_ids=context_input_ids, + context_attention_mask=context_attention_mask, + doc_scores=doc_scores, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_retrieved=output_retrieved, + ) + + loss = None + if labels is not None: + loss = self.get_nll( + outputs.logits, + outputs.doc_scores, + decoder_input_ids, + reduce_loss=reduce_loss, + epsilon=self.config.label_smoothing, + exclude_bos_score=exclude_bos_score, + ) + + return RetrievAugLMMarginOutput( + loss=loss, + logits=outputs.logits, + doc_scores=outputs.doc_scores, + past_key_values=outputs.past_key_values, + context_input_ids=outputs.context_input_ids, + context_attention_mask=outputs.context_attention_mask, + retrieved_doc_embeds=outputs.retrieved_doc_embeds, + retrieved_doc_ids=outputs.retrieved_doc_ids, + question_encoder_last_hidden_state=outputs.question_encoder_last_hidden_state, + question_enc_hidden_states=outputs.question_enc_hidden_states, + question_enc_attentions=outputs.question_enc_attentions, + generator_enc_last_hidden_state=outputs.generator_enc_last_hidden_state, + generator_enc_hidden_states=outputs.generator_enc_hidden_states, + generator_enc_attentions=outputs.generator_enc_attentions, + generator_dec_hidden_states=outputs.generator_dec_hidden_states, + generator_dec_attentions=outputs.generator_dec_attentions, + ) + + @property + def retriever(self): + return self.rag.retriever + + @property + def generator(self): + return self.rag.generator + + @property + def question_encoder(self): + return self.rag.question_encoder + + @torch.no_grad() + def generate( + self, + input_ids, + context_input_ids=None, + do_deduplication=None, # defaults to True + num_return_sequences=None, # defaults to 1 + num_beams=None, # defaults to 1 + **kwargs + ): + """ + Implements RAG sequence "thorough" decoding. + Read the :meth:`~transformers.PreTrainedModel.generate`` documentation for more information on how to set other generate input parameters. + + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + The sequence used as a prompt for the generation. If :obj:`input_ids` is not passed, then :obj:`context_input_ids` has to be provided. + context_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * config.n_docs, config.max_combined_length)`, `optional`, returned when `output_retrieved=True`): + Input ids post-processed from the retrieved documents and the question encoder input_ids by the retriever. + do_deduplication (:obj:`bool`, `optional`): + Controls whether we want to deduplicate the generations from different context documents for a given input. + Has to be set to :obj:`False` if used while training with distributed backend. + num_return_sequences(:obj:`int`, `optional`, defaults to 1): + The number of independently computed returned sequences for each element in the batch. Note that this is not the value + we pass to the ``generator``'s `:func:`~transformers.PreTrainedModel.generate`` function, where we set ``num_return_sequences`` + to `num_beams`. + num_beams (:obj:`int`, `optional`, defaults to 1): + Number of beams for beam search. 1 means no beam search. + kwargs: + Additional kwargs will be passed to :meth:`~transformers.PreTrainedModel.generate``. + Return: + + :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: + The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or + shorter if all batches finished early due to the :obj:`eos_token_id`. + """ + + do_deduplication = do_deduplication if do_deduplication is not None else self.config.do_deduplication + num_doc_return_sequences = ( + num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences + ) + num_beams = num_beams if num_beams is not None else self.config.num_beams + + # TODO(patrick) - clean up generate here + if self.retriever is not None and context_input_ids is None: + question_hidden_states = self.question_encoder(input_ids)[0] + context_input_ids = self.retriever( + input_ids, + question_hidden_states.cpu().detach().to(torch.float32).numpy(), + prefix=self.generator.config.prefix, + n_docs=self.config.n_docs, + return_tensors="pt", + )["context_input_ids"] + + # set to correct device + context_input_ids = context_input_ids.to(input_ids) + + hypos = [] + kwargs["num_beams"] = num_beams + kwargs["num_return_sequences"] = num_return_sequences + kwargs["attention_mask"] = None + + for index in range(len(input_ids)): + # first, generate beams from documents: + generator_input_ids = context_input_ids[ + index * self.config.n_docs : (index + 1) * self.config.n_docs + ] # (n_docs, max_len) + + output_sequences = self.generator.generate( + generator_input_ids, + **kwargs, + ) # n_docs * n_beam, tgt_len + if do_deduplication: + # do_deduplication, max_output_len + output_sequences = torch.stack(list({str(k.tolist()): k for k in output_sequences}.values())) + + # then, run model forwards to get nll scores: + new_input_ids = input_ids[index : index + 1].repeat(len(output_sequences), 1) + outputs = self(new_input_ids, labels=output_sequences, exclude_bos_score=True) + top_cand_inds = (-outputs["loss"]).topk(num_doc_return_sequences)[1] + + # add hypothesis + hypos.append(output_sequences[top_cand_inds]) + + return self._cat_and_pad(hypos, pad_token_id=self.config.generator.pad_token_id) + + def get_nll(self, seq_logits, doc_scores, target, reduce_loss=False, epsilon=0.0, exclude_bos_score=False): + # shift tokens left + target = torch.cat( + [target[:, 1:], target.new(target.shape[0], 1).fill_(self.config.generator.pad_token_id)], 1 + ) + + # bos_token_id is None for T5 + use_bos = self.config.bos_token_id is not None and target[:, 0].eq(self.config.bos_token_id).all() + + def _mask_pads(ll, smooth_obj): + pad_mask = target.eq(self.config.generator.pad_token_id) + if pad_mask.any(): + ll.masked_fill_(pad_mask, 0.0) + smooth_obj.masked_fill_(pad_mask, 0.0) + return ll.squeeze(-1), smooth_obj.squeeze(-1) + + seq_logprobs = torch.nn.functional.log_softmax(seq_logits, dim=-1).view( + seq_logits.shape[0] // self.config.n_docs, self.config.n_docs, -1, seq_logits.size(-1) + ) # batch_size x n_docs x tgt_len x dim + doc_logprobs = torch.nn.functional.log_softmax(doc_scores, dim=1).unsqueeze(-1).unsqueeze(-1) + + # RAG-sequence marginaliation + first_token_scores = seq_logprobs[:, :, :1, :] + second_token_scores = seq_logprobs[:, :, 1:2, :] + remainder = seq_logprobs[:, :, 2:, :] + rag_logprobs = torch.cat([first_token_scores, second_token_scores + doc_logprobs, remainder], dim=2) + + # calcualate loss + target = target.unsqueeze(1).unsqueeze(-1).repeat(1, self.config.n_docs, 1, 1) + assert target.dim() == rag_logprobs.dim() + + ll = rag_logprobs.gather(dim=-1, index=target) + smooth_obj = rag_logprobs.sum(dim=-1, keepdim=True) # total sum of all (normalised) logits + + ll, smooth_obj = _mask_pads(ll, smooth_obj) + + # sum over tokens, exclude bos while scoring + ll = ll[:, :, 1:].sum(2) if exclude_bos_score and use_bos else ll.sum(2) + smooth_obj = smooth_obj.sum(2) + ll = ll.logsumexp(1) # logsumexp over docs + smooth_obj = smooth_obj.logsumexp(1) + + nll_loss = -ll + smooth_loss = -smooth_obj + + if reduce_loss: + nll_loss = nll_loss.sum() + smooth_loss = smooth_loss.sum() + + eps_i = epsilon / rag_logprobs.size(-1) + loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss + return loss + + @staticmethod + def _cat_and_pad(tensors, pad_token_id): + output = ( + tensors[0].new(sum([t.shape[0] for t in tensors]), max([t.shape[1] for t in tensors])).fill_(pad_token_id) + ) + ind = 0 + for t in tensors: + output[ind : ind + t.shape[0], : t.shape[1]] = t + ind += t.shape[0] + return output + + +@add_start_docstrings_to_callable( + """A RAG-token model impementation. It performs RAG-token specific marginalization in the forward pass. + """, + RAG_START_DOCSTRING, +) +class RagTokenForGeneration(RagPreTrainedModel): + def __init__( + self, + config: Optional[PretrainedConfig] = None, + question_encoder: Optional[PreTrainedModel] = None, + generator: Optional[PreTrainedModel] = None, + retriever: Optional = None, + **kwargs, + ): + assert config is not None or ( + question_encoder is not None and generator is not None + ), "Either a configuration or an encoder and a generator has to be provided." + + if config is None: + config = RagConfig.from_encoder_generator_configs(question_encoder.config, generator.config, **kwargs) + + super().__init__(config) + + # instantiate model + self.rag = RagModel(config=config, question_encoder=question_encoder, generator=generator, retriever=retriever) + + def set_retriever(self, retriever: RagRetriever): + self.rag.retriever = retriever + + def adjust_logits_during_generation(self, logits, cur_len, max_length): + return self.rag.generator.adjust_logits_during_generation(logits, cur_len=cur_len, max_length=max_length) + + def prepare_inputs_for_generation( + self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, doc_scores, **kwargs + ): + return { + "input_ids": None, + "encoder_outputs": encoder_outputs, + "doc_scores": doc_scores, + "context_attention_mask": attention_mask, + "decoder_input_ids": decoder_input_ids, + "past_key_values": past, + "use_cache": use_cache, + "do_marginalize": True, + } + + @property + def retriever(self): + return self.rag.retriever + + @property + def generator(self): + return self.rag.generator + + @property + def question_encoder(self): + return self.rag.question_encoder + + @staticmethod + def _reorder_cache(past, beam_idx): + """Reorders cache for generation. BART-inspired but we need to take care of the extra dimension for docs""" + + def _reorder_stacked(hidden_states): + n_docs = hidden_states.shape[0] // beam_idx.shape[0] + hidden_states = hidden_states.view(-1, n_docs, *hidden_states.shape[1:]) + hidden_states = hidden_states.index_select(0, beam_idx) + return hidden_states.view(-1, *hidden_states.shape[2:]) + + def _reorder_buffer(attn_cache): + for k, input_buffer_k in attn_cache.items(): + if input_buffer_k is not None: + attn_cache[k] = _reorder_stacked(input_buffer_k) + return attn_cache + + reordered_past = [] + for layer_past in past: + # get the correct batch idx from decoder layer's batch dim for cross and self-attn + layer_past_new = {attn_key: _reorder_buffer(attn_cache) for attn_key, attn_cache in layer_past.items()} + reordered_past.append(layer_past_new) + + return reordered_past + + def marginalize(self, seq_logits, doc_scores): + # RAG-token marginalization + seq_logprobs = torch.nn.functional.log_softmax(seq_logits, dim=-1).view( + seq_logits.shape[0] // self.config.n_docs, self.config.n_docs, -1, seq_logits.size(-1) + ) + doc_logprobs = torch.log_softmax(doc_scores, dim=1) + log_prob_sum = seq_logprobs + doc_logprobs.unsqueeze(-1).unsqueeze(-1) + return torch.logsumexp(log_prob_sum, dim=1) + + @add_start_docstrings_to_callable(RAG_FORWARD_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=RetrievAugLMMarginOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_outputs=None, + decoder_input_ids=None, + decoder_attention_mask=None, + past_key_values=None, + context_input_ids=None, + context_attention_mask=None, + doc_scores=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + output_retrieved=None, + do_marginalize=None, + reduce_loss=None, + labels=None, + **kwargs # needs kwargs for generation + ): + r""" + do_marginalize (:obj:`bool`, `optional`): + If :obj:`True`, the logits are marginalized over all documents + by making use of ``torch.nn.functional.log_softmax``. + reduce_loss (:obj:`bool`, `optional`): + Only relevant if ``labels`` is passed. + If :obj:`True`, the NLL loss is reduced using the ``torch.Tensor.sum`` operation. + kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`): + Legacy dictionary, which is required so that model can use `generate()` function. + Returns: + + Example:: + + >>> from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration + >>> import torch + + >>> tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq") + >>> retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True) + >>> # initialize with RagRetriever to do everything in one forward call + >>> model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever) + + >>> input_dict = tokenizer.prepare_seq2seq_batch("How many people live in Paris?", "In Paris, there are 10 million people.", return_tensors="pt") + >>> input_ids = input_dict["input_ids"] + >>> outputs = model(input_ids=input_ids, labels=input_dict["labels"]) + + >>> # or use retriever seperately + >>> model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", use_dummy_dataset=True) + >>> # 1. Encode + >>> question_hidden_states = model.question_encoder(input_ids)[0] + >>> # 2. Retrieve + >>> docs_dict = retriever(input_ids.numpy(), question_hidden_states.detach().numpy(), return_tensors="pt") + >>> doc_scores = torch.bmm(question_hidden_states.unsqueeze(1), docs_dict["retrieved_doc_embeds"].float().transpose(1, 2)).squeeze(1) + >>> # 3. Forward to generator + >>> outputs = model(context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores, decoder_input_ids=input_dict["labels"]) + + >>> # or directly generate + >>> generated = model.generate(input_ids=input_dict["input_ids"]) + >>> generated_string = tokenizer.batch_decode(generated, skip_special_tokens=True) + """ + do_marginalize = do_marginalize if do_marginalize is not None else self.config.do_marginalize + reduce_loss = reduce_loss if reduce_loss is not None else self.config.reduce_loss + + if labels is not None: + if decoder_input_ids is None: + decoder_input_ids = labels + use_cache = False + + outputs = self.rag( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_outputs=encoder_outputs, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + context_input_ids=context_input_ids, + context_attention_mask=context_attention_mask, + doc_scores=doc_scores, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_retrieved=output_retrieved, + ) + + loss = None + logits = outputs.logits + if labels is not None: + assert decoder_input_ids is not None + loss = self.get_nll( + outputs.logits, + outputs.doc_scores, + labels, + reduce_loss=reduce_loss, + epsilon=self.config.label_smoothing, + ) + + if do_marginalize: + logits = self.marginalize(logits, outputs.doc_scores) + + return RetrievAugLMMarginOutput( + loss=loss, + logits=logits, + doc_scores=outputs.doc_scores, + past_key_values=outputs.past_key_values, + context_input_ids=outputs.context_input_ids, + context_attention_mask=outputs.context_attention_mask, + retrieved_doc_embeds=outputs.retrieved_doc_embeds, + retrieved_doc_ids=outputs.retrieved_doc_ids, + question_encoder_last_hidden_state=outputs.question_encoder_last_hidden_state, + question_enc_hidden_states=outputs.question_enc_hidden_states, + question_enc_attentions=outputs.question_enc_attentions, + generator_enc_last_hidden_state=outputs.generator_enc_last_hidden_state, + generator_enc_hidden_states=outputs.generator_enc_hidden_states, + generator_enc_attentions=outputs.generator_enc_attentions, + generator_dec_hidden_states=outputs.generator_dec_hidden_states, + generator_dec_attentions=outputs.generator_dec_attentions, + ) + + @torch.no_grad() + def generate( + self, + input_ids: Optional[torch.LongTensor] = None, + context_input_ids=None, + context_attention_mask=None, + doc_scores=None, + max_length=None, + min_length=None, + early_stopping=None, + use_cache=None, + num_beams=None, + bos_token_id=None, + pad_token_id=None, + eos_token_id=None, + length_penalty=None, + no_repeat_ngram_size=None, + bad_words_ids=None, + num_return_sequences=None, + decoder_start_token_id=None, + **kwargs + ): + """ + Implements RAG token decoding. + + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + The sequence used as a prompt for the generation. If :obj:`input_ids` is not passed, then :obj:`context_input_ids` has to be provided. + context_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * config.n_docs, config.max_combined_length)`, `optional`, returned when `output_retrieved=True`): + Input ids post-processed from the retrieved documents and the question encoder input_ids by the retriever. + If the model has is not initialized with a ``retriever`` :obj:`context_input_ids` has to be provided to the forward pass. :obj:`context_input_ids` are returned by :meth:`~transformers.RagRetriever.__call__` + context_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size * config.n_docs, config.max_combined_length)`, `optional`, returned when `output_retrieved=True`): + Attention mask post-processed from the retrieved documents and the question encoder input_ids by the retriever. + If the model has is not initialized with a ``retriever`` :obj:`context_attention_mask` has to be provided to the forward pass. :obj:`context_attention_mask` are returned by :meth:`~transformers.RagRetriever.__call__` + doc_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.n_docs)`): + Score between each retrieved document embeddigs (see :obj:`retrieved_doc_embeds`) and :obj:`question_encoder_last_hidden_state`. + If the model has is not initialized with a ``retriever`` :obj:`doc_scores` has to be provided to the forward pass. :obj:`doc_scores` can be computed via :obj:`question_encoder_last_hidden_state` and :obj:`retrieved_doc_embeds`, see examples for more information. + + max_length (:obj:`int`, `optional`, defaults to 20): + The maximum length of the sequence to be generated. + min_length (:obj:`int`, `optional`, defaults to 10): + The minimum length of the sequence to be generated. + early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not. + use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not the model should use the past last key/values attentions (if applicable to the model) to + speed up decoding. + pad_token_id (:obj:`int`, `optional`): + The id of the `padding` token. + bos_token_id (:obj:`int`, `optional`): + The id of the `beginning-of-sequence` token. + eos_token_id (:obj:`int`, `optional`): + The id of the `end-of-sequence` token. + length_penalty (:obj:`float`, `optional`, defaults to 1.0): + Exponential penalty to the length. 1.0 means no penalty. + + Set to values < 1.0 in order to encourage the model to generate shorter sequences, to a value > 1.0 in + order to encourage the model to produce longer sequences. + no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0): + If set to int > 0, all ngrams of that size can only occur once. + bad_words_ids(:obj:`List[int]`, `optional`): + List of token ids that are not allowed to be generated. In order to get the tokens of the words that + should not appear in the generated text, use :obj:`tokenizer.encode(bad_word, add_prefix_space=True)`. + num_beams (:obj:`int`, `optional`, defaults to 1): + Number of beams for beam search. 1 means no beam search. + num_return_sequences(:obj:`int`, `optional`, defaults to 1): + The number of independently computed returned sequences for each element in the batch. Note that this is not the value + we pass to the ``generator``'s `:func:`~transformers.PreTrainedModel.generate`` function, where we set ``num_return_sequences`` + to `num_beams`. + decoder_start_token_id (:obj:`int`, `optional`): + If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token. + + Return: + + :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: + The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or + shorter if all batches finished early due to the :obj:`eos_token_id`. + """ + # set default parameters + max_length = max_length if max_length is not None else self.config.max_length + min_length = min_length if min_length is not None else self.config.min_length + early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping + use_cache = use_cache if use_cache is not None else self.config.use_cache + num_beams = num_beams if num_beams is not None else self.config.num_beams + bos_token_id = bos_token_id if bos_token_id is not None else self.config.generator.bos_token_id + pad_token_id = pad_token_id if pad_token_id is not None else self.config.generator.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.config.generator.eos_token_id + length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty + no_repeat_ngram_size = ( + no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size + ) + bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids + num_return_sequences = ( + num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences + ) + decoder_start_token_id = ( + decoder_start_token_id + if decoder_start_token_id is not None + else self.config.generator.decoder_start_token_id + ) + + # batch_size + batch_size = input_ids.shape[0] + + # retrieve docs + if self.retriever is not None and context_input_ids is None: + question_hidden_states = self.question_encoder(input_ids)[0] + out = self.retriever( + input_ids, + question_hidden_states.cpu().detach().to(torch.float32).numpy(), + prefix=self.generator.config.prefix, + n_docs=self.config.n_docs, + return_tensors="pt", + ) + context_input_ids, context_attention_mask, retrieved_doc_embeds = ( + out["context_input_ids"], + out["context_attention_mask"], + out["retrieved_doc_embeds"], + ) + + # set to correct device + retrieved_doc_embeds = retrieved_doc_embeds.to(question_hidden_states) + context_input_ids = context_input_ids.to(input_ids) + context_attention_mask = context_attention_mask.to(input_ids) + + # compute doc_scores + doc_scores = torch.bmm(question_hidden_states.unsqueeze(1), retrieved_doc_embeds.transpose(1, 2)).squeeze( + 1 + ) + + encoder = self.rag.generator.get_encoder() + encoder_outputs = encoder(input_ids=context_input_ids, attention_mask=context_attention_mask, return_dict=True) + + decoder_input_ids = torch.full( + (batch_size * num_beams, 1), + decoder_start_token_id, + dtype=torch.long, + device=next(self.parameters()).device, + ) + last_hidden_state = encoder_outputs["last_hidden_state"] + + def extend_enc_output(tensor, num_beams=None): + # split into `batch_size`, `num_beams`, `num_docs` + tensor = tensor[None, None, :].reshape((batch_size, 1, self.config.n_docs) + tensor.shape[1:]) + # repeat same last hidden states over `num_beams` dimension + tensor = tensor.expand((batch_size, num_beams, self.config.n_docs) + tensor.shape[3:]) + # merge `batch_size`, `num_beams`, `num_docs` dims again + return tensor.reshape((batch_size * num_beams * self.config.n_docs,) + tensor.shape[3:]) + + # correctly extend last_hidden_state and attention mask + context_attention_mask = extend_enc_output(context_attention_mask, num_beams=num_beams) + encoder_outputs["last_hidden_state"] = extend_enc_output(last_hidden_state, num_beams=num_beams) + + doc_scores = doc_scores.repeat_interleave(num_beams, dim=0) + + # define start_len & additional parameters + cur_len = 1 + vocab_size = self.config.generator.vocab_size + kwargs["doc_scores"] = doc_scores + kwargs["encoder_outputs"] = encoder_outputs + + # not needed. TODO(PVP): change after generate refactor + do_sample = False + temperature = self.config.temperature + top_k = self.config.top_k + top_p = self.config.top_p + repetition_penalty = self.config.repetition_penalty + + if num_beams > 1: + return self._generate_beam_search( + decoder_input_ids, + cur_len=cur_len, + max_length=max_length, + min_length=min_length, + do_sample=do_sample, + early_stopping=early_stopping, + temperature=temperature, + top_k=top_k, + top_p=top_p, + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=no_repeat_ngram_size, + bad_words_ids=bad_words_ids, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + batch_size=batch_size, + num_return_sequences=num_return_sequences, + length_penalty=length_penalty, + num_beams=num_beams, + vocab_size=vocab_size, + attention_mask=context_attention_mask, + use_cache=use_cache, + model_kwargs=kwargs, + ) + else: + return self._generate_no_beam_search( + decoder_input_ids, + cur_len=cur_len, + max_length=max_length, + min_length=min_length, + do_sample=do_sample, + temperature=temperature, + top_k=top_k, + top_p=top_p, + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=no_repeat_ngram_size, + bad_words_ids=bad_words_ids, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + batch_size=batch_size, + attention_mask=context_attention_mask, + use_cache=use_cache, + model_kwargs=kwargs, + ) + + def get_input_embeddings(self): + return self.rag.generator.get_input_embeddings() + + def get_output_embeddings(self): + return self.rag.generator.get_output_embeddings() + + def shift_tokens_right(self, input_ids, start_token_id=None): + """Shift input ids one token to the right, and pad with start_token_id""" + if start_token_id is None: + start_token_id = self.config.decoder_start_token_id + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = start_token_id + return shifted_input_ids + + def get_nll(self, seq_logits, doc_scores, target, reduce_loss=False, epsilon=0.0): + # shift tokens left + target = torch.cat( + [target[:, 1:], target.new(target.shape[0], 1).fill_(self.config.generator.pad_token_id)], 1 + ) + + def _mask_pads(ll, smooth_obj): + pad_mask = target.eq(self.config.generator.pad_token_id) + if pad_mask.any(): + ll.masked_fill_(pad_mask, 0.0) + smooth_obj.masked_fill_(pad_mask, 0.0) + return ll.squeeze(-1), smooth_obj.squeeze(-1) + + rag_logprobs = self.marginalize(seq_logits, doc_scores) + + target = target.unsqueeze(-1) + assert target.dim() == rag_logprobs.dim() + + ll = rag_logprobs.gather(dim=-1, index=target) + smooth_obj = rag_logprobs.sum(dim=-1, keepdim=True) # total sum of all (normalised) logits + ll, smooth_obj = _mask_pads(ll, smooth_obj) + ll = ll.sum(1) # sum over tokens + smooth_obj = smooth_obj.sum(1) + + nll_loss = -ll + smooth_loss = -smooth_obj + + if reduce_loss: + nll_loss = nll_loss.sum() + smooth_loss = smooth_loss.sum() + + eps_i = epsilon / rag_logprobs.size(-1) + loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss + return loss diff --git a/src/transformers/retrieval_rag.py b/src/transformers/retrieval_rag.py new file mode 100644 index 000000000..644cf666e --- /dev/null +++ b/src/transformers/retrieval_rag.py @@ -0,0 +1,470 @@ +# coding=utf-8 +# Copyright 2020, The RAG Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""RAG Retriever model implementation.""" + +import os +import pickle +import time +from typing import Iterable, List, Optional, Tuple + +import numpy as np + +from .configuration_rag import RagConfig +from .file_utils import cached_path, is_datasets_available, is_faiss_available, is_remote_url +from .tokenization_rag import RagTokenizer +from .tokenization_utils_base import BatchEncoding +from .utils import logging + + +if is_datasets_available() and is_faiss_available(): + from datasets import load_dataset + + import faiss + + +logger = logging.get_logger(__name__) + + +LEGACY_INDEX_PATH = "https://storage.googleapis.com/huggingface-nlp/datasets/wiki_dpr/" + + +class Index: + """ + A base class for the Indices encapsulated by the :class:`~transformers.RagRetriever`. + """ + + def get_doc_dicts(self, doc_ids: np.ndarray) -> List[dict]: + """ + Returns a list of dictionaries, containing titles and text of the retrieved documents. + + Args: + doc_ids (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs)`): + A tensor of document indices. + """ + raise NotImplementedError + + def get_top_docs(self, question_hidden_states: np.ndarray, n_docs=5) -> Tuple[np.ndarray, np.ndarray]: + """ + For each query in the batch, retrieves ``n_docs`` documents. + + Args: + question_hidden_states (:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size): + An array of query vectors. + n_docs (:obj:`int`): + The number of docs retrieved per query. + + Returns: + :obj:`np.ndarray` of shape :obj:`(batch_size, n_docs)`: A tensor of indices of retrieved documents. + :obj:`np.ndarray` of shape :obj:`(batch_size, vector_size)`: A tensor of vector representations of retrieved documents. + """ + raise NotImplementedError + + def is_initialized(self): + """ + Returns :obj:`True` if index is already initialized. + """ + raise NotImplementedError + + def init_index(self): + """ + A function responsible for loading the index into memory. Should be called only once per training run of a RAG model. + E.g. if the model is trained on multiple GPUs in a distributed setup, only one of the workers will load the index. + """ + raise NotImplementedError + + +class LegacyIndex: + """ + An index which can be deserialized from the files built using https://github.com/facebookresearch/DPR. + We use default faiss index parameters as specified in that repository. + + Args: + vector_size (:obj:`int`): + The dimension of indexed vectors. + index_path (:obj:`str`): + A path to a `directory` containing index files compatible with + :class:`~transformers.retrieval_rag.LegacyIndex` + """ + + INDEX_FILENAME = "hf_bert_base.hnswSQ8_correct_phi_128.c_index" + PASSAGE_FILENAME = "psgs_w100.tsv.pkl" + + def __init__(self, vector_size, index_path): + self.index_id_to_db_id = [] + self.index_path = index_path + self.passages = self._load_passages() + self.vector_size = vector_size + self.index = None + self._index_initialize = False + + def _resolve_path(self, index_path, filename): + assert os.path.isdir(index_path) or is_remote_url(index_path), "Please specify a valid ``index_path``." + archive_file = os.path.join(index_path, filename) + try: + # Load from URL or cache if already cached + resolved_archive_file = cached_path(archive_file) + if resolved_archive_file is None: + raise EnvironmentError + except EnvironmentError: + msg = ( + f"Can't load '{archive_file}'. Make sure that:\n\n" + f"- '{index_path}' is a correct remote path to a directory containing a file named {filename}" + f"- or '{index_path}' is the correct path to a directory containing a file named {filename}.\n\n" + ) + raise EnvironmentError(msg) + if resolved_archive_file == archive_file: + logger.info("loading file {}".format(archive_file)) + else: + logger.info("loading file {} from cache at {}".format(archive_file, resolved_archive_file)) + return resolved_archive_file + + def _load_passages(self): + logger.info("Loading passages from {}".format(self.index_path)) + passages_path = self._resolve_path(self.index_path, self.PASSAGE_FILENAME) + with open(passages_path, "rb") as passages_file: + passages = pickle.load(passages_file) + return passages + + def _deserialize_index(self): + logger.info("Loading index from {}".format(self.index_path)) + resolved_index_path = self._resolve_path(self.index_path, self.INDEX_FILENAME + ".index.dpr") + self.index = faiss.read_index(resolved_index_path) + resolved_meta_path = self._resolve_path(self.index_path, self.INDEX_FILENAME + ".index_meta.dpr") + with open(resolved_meta_path, "rb") as metadata_file: + self.index_id_to_db_id = pickle.load(metadata_file) + assert ( + len(self.index_id_to_db_id) == self.index.ntotal + ), "Deserialized index_id_to_db_id should match faiss index size" + + def is_initialized(self): + return self._index_initialize + + def init_index(self): + index = faiss.IndexHNSWFlat(self.vector_size + 1, 512) + index.hnsw.efSearch = 128 + index.hnsw.efConstruction = 200 + self.index = index + self._deserialize_index() + self._index_initialize = True + + def get_doc_dicts(self, doc_ids: np.array): + doc_list = [] + for doc_ids_i in doc_ids: + ids = [str(int(doc_id)) for doc_id in doc_ids_i] + docs = [self.passages[doc_id] for doc_id in ids] + doc_list.append(docs) + doc_dicts = [] + for docs in doc_list: + doc_dict = {} + doc_dict["title"] = [doc[1] for doc in docs] + doc_dict["text"] = [doc[0] for doc in docs] + doc_dicts.append(doc_dict) + return doc_dicts + + def get_top_docs(self, question_hidden_states: np.ndarray, n_docs=5) -> Tuple[np.ndarray, np.ndarray]: + aux_dim = np.zeros(len(question_hidden_states), dtype="float32").reshape(-1, 1) + query_nhsw_vectors = np.hstack((question_hidden_states, aux_dim)) + _, docs_ids = self.index.search(query_nhsw_vectors, n_docs) + vectors = [[self.index.reconstruct(int(doc_id))[:-1] for doc_id in doc_ids] for doc_ids in docs_ids] + ids = [[int(self.index_id_to_db_id[doc_id]) for doc_id in doc_ids] for doc_ids in docs_ids] + return np.array(ids), np.array(vectors) + + +class HFIndex: + """ + A wrapper around an instance of :class:`~datasets.Datasets`. If ``index_path`` is set to ``None``, + we load the pre-computed index available with the :class:`~datasets.arrow_dataset.Dataset`, otherwise, we load the index from the indicated path on disk. + + Args: + dataset (:obj:`str`, optional, defaults to ``wiki_dpr``): + A datatset identifier of the indexed dataset on HuggingFace AWS bucket (list all available datasets and ids with ``datasets.list_datasets()``). + dataset_split (:obj:`str`, optional, defaults to ``train``) + Which split of the ``dataset`` to load. + index_name (:obj:`str`, optional, defaults to ``train``) + The index_name of the index associated with the ``dataset``. The index loaded from ``index_path`` will be saved under this name. + index_path (:obj:`str`, optional, defaults to ``None``) + The path to the serialized faiss index on disk. + """ + + def __init__( + self, + dataset_name: str, + dataset_split: str, + index_name: str, + index_path: Optional[str] = None, + use_dummy_dataset=False, + ): + super().__init__() + self.dataset_name = dataset_name + self.dataset_split = dataset_split + self.index_name = index_name + self.index_path = index_path + self.use_dummy_dataset = use_dummy_dataset + self._index_initialize = False + + logger.info("Loading passages from {}".format(self.dataset_name)) + self.dataset = load_dataset( + self.dataset_name, with_index=False, split=self.dataset_split, dummy=self.use_dummy_dataset + ) + + def is_initialized(self): + return self._index_initialize + + def init_index(self): + if self.index_path is not None: + logger.info("Loading index from {}".format(self.index_path)) + self.index.load_faiss_index(index_name=self.index_name, file=self.index_path) + else: + logger.info("Loading index from {}".format(self.dataset_name + " with index name " + self.index_name)) + self.dataset = load_dataset( + self.dataset_name, + with_embeddings=True, + with_index=True, + split=self.dataset_split, + index_name=self.index_name, + dummy=self.use_dummy_dataset, + ) + self._index_initialize = True + + def get_doc_dicts(self, doc_ids: np.ndarray) -> List[dict]: + return [self.dataset[doc_ids[i].tolist()] for i in range(doc_ids.shape[0])] + + def get_top_docs(self, question_hidden_states: np.ndarray, n_docs=5) -> Tuple[np.ndarray, np.ndarray]: + _, docs = self.dataset.get_nearest_examples_batch("embeddings", question_hidden_states, n_docs) + ids = [[int(i) for i in doc["id"]] for doc in docs] + vectors = [doc["embeddings"] for doc in docs] + return np.array(ids), np.array(vectors) # shapes (batch_size, n_docs) and (batch_size, n_docs, d) + + +class RagRetriever: + """ + Retriever used to get documents from vector queries. + It retrieves the documents embeddings as well as the documents contents, and it formats them to be used with a RagModel. + + Args: + config (:class:`~transformers.RagConfig`): + The configuration of the RAG model this Retriever is used with. Contains parameters indicating which ``Index`` to build. + question_encoder_tokenizer (:class:`~transformers.PreTrainedTokenizer`): + The tokenizer that was used to tokenize the question. + It is used to decode the question and then use the generator_tokenizer. + generator_tokenizer (:class:`~transformers.PreTrainedTokenizer`): + The tokenizer used for the generator part of the RagModel. + """ + + _init_retrieval = True + + def __init__(self, config, question_encoder_tokenizer, generator_tokenizer): + super().__init__() + self.index = ( + LegacyIndex( + config.retrieval_vector_size, + config.index_path or LEGACY_INDEX_PATH, + ) + if config.index_name == "legacy" + else HFIndex( + config.dataset, config.dataset_split, config.index_name, config.index_path, config.use_dummy_dataset + ) + ) + self.generator_tokenizer = generator_tokenizer + self.question_encoder_tokenizer = question_encoder_tokenizer + + self.n_docs = config.n_docs + self.batch_size = config.retrieval_batch_size + + self.config = config + if self._init_retrieval: + self.init_retrieval() + + @classmethod + def from_pretrained(cls, retriever_name_or_path, **kwargs): + config = RagConfig.from_pretrained(retriever_name_or_path, **kwargs) + rag_tokenizer = RagTokenizer.from_pretrained(retriever_name_or_path, config=config) + question_encoder_tokenizer = rag_tokenizer.question_encoder + generator_tokenizer = rag_tokenizer.generator + return cls( + config, question_encoder_tokenizer=question_encoder_tokenizer, generator_tokenizer=generator_tokenizer + ) + + def save_pretrained(self, save_directory): + self.config.save_pretrained(save_directory) + rag_tokenizer = RagTokenizer( + question_encoder_tokenizer=self.question_encoder_tokenizer, + generator_tokenizer=self.generator_tokenizer, + ) + rag_tokenizer.save_pretrained(save_directory) + + def init_retrieval(self): + """ + Retriever initalization function. It loads the index into memory. + """ + + logger.info("initializing retrieval") + self.index.init_index() + + def postprocess_docs(self, docs, input_strings, prefix, n_docs, return_tensors=None): + r""" + Postprocessing retrieved ``docs`` and combining them with ``input_strings``. + + Args: + doc_scores (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs)`): + Retrieval scores of respective docs - passed for logging. + docs (:obj:`dict`): + Retrieved documents. + input_strings (:obj:`str`): + Input strings decoded by ``preprocess_query``. + prefix (:obj:`str`): + Prefix added at the beginning of each input, typically used with T5-based models. + + Return: + :obj:`tuple(tensors)`: + a tuple consisting of two elements: contextualized ``input_ids`` and a compatible ``attention_mask``. + """ + + def cat_input_and_doc(doc_title, doc_text, input_string, prefix): + # TODO(Patrick): if we train more RAG models, I want to put the input first to take advantage of effortless truncation + # TODO(piktus): better handling of truncation + if doc_title.startswith('"'): + doc_title = doc_title[1:] + if doc_title.endswith('"'): + doc_title = doc_title[:-1] + if prefix is None: + prefix = "" + out = (prefix + doc_title + self.config.title_sep + doc_text + self.config.doc_sep + input_string).replace( + " ", " " + ) + return out + + rag_input_strings = [ + cat_input_and_doc( + docs[i]["title"][j], + docs[i]["text"][j], + input_strings[i], + prefix, + ) + for i in range(len(docs)) + for j in range(n_docs) + ] + + contextualized_inputs = self.generator_tokenizer.batch_encode_plus( + rag_input_strings, + max_length=self.config.max_combined_length, + return_tensors=return_tensors, + padding="max_length", + truncation=True, + ) + + return contextualized_inputs["input_ids"], contextualized_inputs["attention_mask"] + + def _chunk_tensor(self, t: Iterable, chunk_size: int) -> List[Iterable]: + return [t[i : i + chunk_size] for i in range(0, len(t), chunk_size)] + + def _main_retrieve(self, question_hidden_states: np.ndarray, n_docs: int) -> Tuple[np.ndarray, np.ndarray]: + question_hidden_states_batched = self._chunk_tensor(question_hidden_states, self.batch_size) + ids_batched = [] + vectors_batched = [] + for question_hidden_states in question_hidden_states_batched: + start_time = time.time() + ids, vectors = self.index.get_top_docs(question_hidden_states, n_docs) + logger.debug( + "index search time: {} sec, batch size {}".format( + time.time() - start_time, question_hidden_states.shape + ) + ) + ids_batched.extend(ids) + vectors_batched.extend(vectors) + return np.array(ids_batched), np.array( + vectors_batched + ) # shapes (batch_size, n_docs) and (batch_size, n_docs, d) + + def retrieve(self, question_hidden_states: np.ndarray, n_docs: int) -> Tuple[np.ndarray, List[dict]]: + """ + Retrieves documents for specified ``question_hidden_states``. + + Args: + question_hidden_states (:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size)`): + A batch of query vectors to retrieve with. + n_docs (:obj:`int`): + The number of docs retrieved per query. + + Return: + retrieved_doc_embeds (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs, dim)` + The retrieval embeddings of the retrieved docs per query. + doc_ids (:obj:`np.ndarray` of shape :obj:`batch_size, n_docs`) + The ids of the documents in the index + doc_dicts (:obj:`List[dict]`): + The retrieved_doc_embeds examples per query. + """ + + doc_ids, retrieved_doc_embeds = self._main_retrieve(question_hidden_states, n_docs) + return retrieved_doc_embeds, doc_ids, self.index.get_doc_dicts(doc_ids) + + def __call__( + self, + question_input_ids: List[List[int]], + question_hidden_states: np.ndarray, + prefix=None, + n_docs=None, + return_tensors=None, + ) -> BatchEncoding: + """ + Retrieves documents for specified :obj:`question_hidden_states`. + + Args: + question_input_ids: (:obj:`List[List[int]]`) batch of input ids + question_hidden_states (:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size)`: + A batch of query vectors to retrieve with. + prefix: (:obj:`str`, `optional`): + The prefix used by the generator's tokenizer. + n_docs (:obj:`int`, `optional`): + The number of docs retrieved per query. + return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to "pt"): + If set, will return tensors instead of list of python integers. Acceptable values are: + + * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects. + * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects. + * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects. + + Output: + :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields: + + - **context_input_ids** -- List of token ids to be fed to a model. + + `What are input IDs? <../glossary.html#input-ids>`__ + - **context_attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + :obj:`return_attention_mask=True` or if `"attention_mask"` is in :obj:`self.model_input_names`). + + `What are attention masks? <../glossary.html#attention-mask>`__ + - **retrieved_doc_embeds** -- List of embeddings of the retrieved documents + - **doc_ids** -- List of ids of the retrieved documents + """ + + n_docs = n_docs if n_docs is not None else self.n_docs + prefix = prefix if prefix is not None else self.config.generator.prefix + retrieved_doc_embeds, doc_ids, docs = self.retrieve(question_hidden_states, n_docs) + + input_strings = self.question_encoder_tokenizer.batch_decode(question_input_ids, skip_special_tokens=True) + context_input_ids, context_attention_mask = self.postprocess_docs( + docs, input_strings, prefix, n_docs, return_tensors=return_tensors + ) + + return BatchEncoding( + { + "context_input_ids": context_input_ids, + "context_attention_mask": context_attention_mask, + "retrieved_doc_embeds": retrieved_doc_embeds, + "doc_ids": doc_ids, + }, + tensor_type=return_tensors, + ) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 49e1dff95..d009ad3d5 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -10,7 +10,7 @@ from distutils.util import strtobool from io import StringIO from pathlib import Path -from .file_utils import _tf_available, _torch_available, _torch_tpu_available +from .file_utils import _datasets_available, _faiss_available, _tf_available, _torch_available, _torch_tpu_available SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy" @@ -161,6 +161,21 @@ def require_torch_and_cuda(test_case): return test_case +def require_datasets(test_case): + """Decorator marking a test that requires datasets.""" + + if not _datasets_available: + test_case = unittest.skip("test requires Datasets")(test_case) + return test_case + + +def require_faiss(test_case): + """Decorator marking a test that requires faiss.""" + if not _faiss_available: + test_case = unittest.skip("test requires Faiss")(test_case) + return test_case + + def get_tests_dir(): """ returns the full path to the `tests` dir, so that the tests can be invoked from anywhere diff --git a/src/transformers/tokenization_auto.py b/src/transformers/tokenization_auto.py index 984d2c0ee..320bc574e 100644 --- a/src/transformers/tokenization_auto.py +++ b/src/transformers/tokenization_auto.py @@ -26,6 +26,7 @@ from .configuration_auto import ( CamembertConfig, CTRLConfig, DistilBertConfig, + DPRConfig, ElectraConfig, EncoderDecoderConfig, FlaubertConfig, @@ -40,6 +41,7 @@ from .configuration_auto import ( MobileBertConfig, OpenAIGPTConfig, PegasusConfig, + RagConfig, ReformerConfig, RetriBertConfig, RobertaConfig, @@ -60,6 +62,7 @@ from .tokenization_bertweet import BertweetTokenizer from .tokenization_camembert import CamembertTokenizer from .tokenization_ctrl import CTRLTokenizer from .tokenization_distilbert import DistilBertTokenizer, DistilBertTokenizerFast +from .tokenization_dpr import DPRQuestionEncoderTokenizer, DPRQuestionEncoderTokenizerFast from .tokenization_electra import ElectraTokenizer, ElectraTokenizerFast from .tokenization_flaubert import FlaubertTokenizer from .tokenization_fsmt import FSMTTokenizer @@ -74,6 +77,7 @@ from .tokenization_mobilebert import MobileBertTokenizer, MobileBertTokenizerFas from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast from .tokenization_pegasus import PegasusTokenizer from .tokenization_phobert import PhobertTokenizer +from .tokenization_rag import RagTokenizer from .tokenization_reformer import ReformerTokenizer from .tokenization_retribert import RetriBertTokenizer, RetriBertTokenizerFast from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast @@ -110,6 +114,7 @@ TOKENIZER_MAPPING = OrderedDict( (FunnelConfig, (FunnelTokenizer, FunnelTokenizerFast)), (LxmertConfig, (LxmertTokenizer, LxmertTokenizerFast)), (LayoutLMConfig, (LayoutLMTokenizer, LayoutLMTokenizerFast)), + (DPRConfig, (DPRQuestionEncoderTokenizer, DPRQuestionEncoderTokenizerFast)), (BertConfig, (BertTokenizer, BertTokenizerFast)), (OpenAIGPTConfig, (OpenAIGPTTokenizer, OpenAIGPTTokenizerFast)), (GPT2Config, (GPT2Tokenizer, GPT2TokenizerFast)), @@ -121,6 +126,7 @@ TOKENIZER_MAPPING = OrderedDict( (FSMTConfig, (FSMTTokenizer, None)), (BertGenerationConfig, (BertGenerationTokenizer, None)), (LayoutLMConfig, (LayoutLMTokenizer, None)), + (RagConfig, (RagTokenizer, None)), ] ) diff --git a/src/transformers/tokenization_rag.py b/src/transformers/tokenization_rag.py new file mode 100644 index 000000000..7307d41f6 --- /dev/null +++ b/src/transformers/tokenization_rag.py @@ -0,0 +1,160 @@ +# coding=utf-8 +# Copyright 2020, The RAG Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for RAG.""" +import os +from typing import List, Optional + +from .configuration_rag import RagConfig +from .tokenization_utils_base import BatchEncoding +from .utils import logging + + +logger = logging.get_logger(__name__) + + +class RagTokenizer: + def __init__(self, question_encoder, generator): + self.question_encoder = question_encoder + self.generator = generator + + def save_pretrained(self, save_directory): + if os.path.isfile(save_directory): + raise ValueError("Provided path ({}) should be a directory, not a file".format(save_directory)) + os.makedirs(save_directory, exist_ok=True) + question_encoder_path = os.path.join(save_directory, "question_encoder_tokenizer") + generator_path = os.path.join(save_directory, "generator_tokenizer") + self.question_encoder.save_pretrained(question_encoder_path) + self.generator.save_pretrained(generator_path) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + # dynamically import AutoTokenizer + from .tokenization_auto import AutoTokenizer + + config = kwargs.pop("config", None) + + if config is None: + config = RagConfig.from_pretrained(pretrained_model_name_or_path) + + question_encoder_path = os.path.join(pretrained_model_name_or_path, "question_encoder_tokenizer") + generator_path = os.path.join(pretrained_model_name_or_path, "generator_tokenizer") + question_encoder = AutoTokenizer.from_pretrained(question_encoder_path, config=config.question_encoder) + generator = AutoTokenizer.from_pretrained(generator_path, config=config.generator) + return cls(question_encoder=question_encoder, generator=generator) + + def __call__(self, *args, **kwargs): + return self.question_encoder(*args, **kwargs) + + def batch_decode(self, *args, **kwargs): + return self.generator.batch_decode(*args, **kwargs) + + def prepare_seq2seq_batch( + self, + src_texts: List[str], + tgt_texts: Optional[List[str]] = None, + max_length: Optional[int] = None, + max_target_length: Optional[int] = None, + padding: str = "longest", + return_tensors: str = "np", + truncation=True, + **kwargs, + ) -> BatchEncoding: + r""" + + Prepare a batch that can be passed directly to an instance of :class:`~transformers.RagModel`. + + Args: + src_texts: (:obj:`List[str]`): + List of documents to summarize or source language texts. + tgt_texts: (:obj:`List[str]`, `optional`): + List of summaries or target language texts. + max_length (:obj:`int`, `optional`): + Controls the maximum length for encoder inputs (documents to summarize or source language texts). + If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum + length is required by one of the truncation/padding parameters. If the model has no specific maximum + input length (like XLNet) truncation/padding to a maximum length will be deactivated. + max_target_length (:obj:`int`, `optional`): + Controls the maximum length of decoder inputs (target language texts or summaries). + If left unset or set to :obj:`None`, this will use the max_length value. + padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`): + Activates and controls padding. Accepts the following values: + + * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a + single sequence if provided). + * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the + maximum acceptable input length for the model if that argument is not provided. + * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of + different lengths). + return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to "pt"): + If set, will return tensors instead of list of python integers. Acceptable values are: + + * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects. + * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects. + * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects. + truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`): + Activates and controls truncation. Accepts the following values: + + * :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument + :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not + provided. This will truncate token by token, removing a token from the longest sequence in the pair + if a pair of sequences (or a batch of pairs) is provided. + * :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to + the maximum acceptable input length for the model if that argument is not provided. This will only + truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + * :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or + to the maximum acceptable input length for the model if that argument is not provided. This will only + truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + * :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with + sequence lengths greater than the model maximum admissible input size). + **kwargs: + Additional keyword arguments passed along to :obj:`self.__call__`. + + Returns: + :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields: + + - **input_ids** -- List of token ids to be fed to the encoder. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model. + - **labels** -- List of token ids for tgt_texts + + The full set of keys ``[input_ids, attention_mask, labels]``, + will only be returned if tgt_texts is passed. Otherwise, input_ids, attention_mask will be the only keys. + """ + if max_length is None: + max_length = self.question_encoder.model_max_length + model_inputs: BatchEncoding = self.question_encoder( + src_texts, + add_special_tokens=True, + return_tensors=return_tensors, + max_length=max_length, + padding=padding, + truncation=truncation, + **kwargs, + ) + if tgt_texts is None: + return model_inputs + # Process tgt_texts + if max_target_length is None: + max_target_length = self.generator.model_max_length + labels = self.generator( + tgt_texts, + add_special_tokens=True, + return_tensors=return_tensors, + padding=padding, + max_length=max_target_length, + truncation=truncation, + **kwargs, + )["input_ids"] + model_inputs["labels"] = labels + return model_inputs diff --git a/tests/test_modeling_rag.py b/tests/test_modeling_rag.py new file mode 100644 index 000000000..0e4442031 --- /dev/null +++ b/tests/test_modeling_rag.py @@ -0,0 +1,885 @@ +# coding=utf-8 +# Copyright 2020, The RAG Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import json +import os +import shutil +import tempfile +import unittest +from unittest.mock import patch + +import numpy as np + +from transformers.file_utils import cached_property, is_datasets_available, is_faiss_available, is_torch_available +from transformers.testing_utils import require_torch, slow, torch_device +from transformers.tokenization_bart import BartTokenizer +from transformers.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES +from transformers.tokenization_dpr import DPRQuestionEncoderTokenizer +from transformers.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES +from transformers.tokenization_t5 import T5Tokenizer + +from .test_modeling_bart import ModelTester as BartModelTester +from .test_modeling_dpr import DPRModelTester +from .test_modeling_t5 import T5ModelTester + + +TOLERANCE = 1e-3 + +T5_SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model") + +if is_torch_available() and is_datasets_available() and is_faiss_available(): + import torch + from datasets import Dataset + + import faiss + from transformers import ( + AutoConfig, + AutoModel, + AutoModelForSeq2SeqLM, + RagConfig, + RagModel, + RagRetriever, + RagSequenceForGeneration, + RagTokenForGeneration, + ) + from transformers.modeling_outputs import BaseModelOutput + + +def _assert_tensors_equal(a, b, atol=1e-12, prefix=""): + """If tensors not close, or a and b arent both tensors, raise a nice Assertion error.""" + if a is None and b is None: + return True + try: + if torch.allclose(a, b, atol=atol): + return True + raise + except Exception: + msg = "{} != {}".format(a, b) + if prefix: + msg = prefix + ": " + msg + raise AssertionError(msg) + + +def require_retrieval(test_case): + """ + Decorator marking a test that requires a set of dependencies necessary for pefrorm retrieval with + :class:`~transformers.RagRetriever`. + + These tests are skipped when respective libraries are not installed. + + """ + if not (is_torch_available() and is_datasets_available() and is_faiss_available()): + test_case = unittest.skip("test requires PyTorch")(test_case) + return test_case + + +@require_torch +@require_retrieval +class RagTestMixin: + + all_model_classes = ( + (RagModel, RagTokenForGeneration, RagSequenceForGeneration) + if is_torch_available() and is_datasets_available() and is_faiss_available() + else () + ) + + retrieval_vector_size = 32 + n_docs = 2 + max_combined_length = 16 + + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + + # DPR tok + vocab_tokens = [ + "[UNK]", + "[CLS]", + "[SEP]", + "[PAD]", + "[MASK]", + "want", + "##want", + "##ed", + "wa", + "un", + "runn", + "##ing", + ",", + "low", + "lowest", + ] + dpr_tokenizer_path = os.path.join(self.tmpdirname, "dpr_tokenizer") + os.makedirs(dpr_tokenizer_path, exist_ok=True) + self.vocab_file = os.path.join(dpr_tokenizer_path, DPR_VOCAB_FILES_NAMES["vocab_file"]) + with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer: + vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) + + # BART tok + vocab = [ + "l", + "o", + "w", + "e", + "r", + "s", + "t", + "i", + "d", + "n", + "\u0120", + "\u0120l", + "\u0120n", + "\u0120lo", + "\u0120low", + "er", + "\u0120lowest", + "\u0120newer", + "\u0120wider", + "", + ] + vocab_tokens = dict(zip(vocab, range(len(vocab)))) + merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""] + self.special_tokens_map = {"unk_token": ""} + + bart_tokenizer_path = os.path.join(self.tmpdirname, "bart_tokenizer") + os.makedirs(bart_tokenizer_path, exist_ok=True) + self.vocab_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["vocab_file"]) + self.merges_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["merges_file"]) + with open(self.vocab_file, "w", encoding="utf-8") as fp: + fp.write(json.dumps(vocab_tokens) + "\n") + with open(self.merges_file, "w", encoding="utf-8") as fp: + fp.write("\n".join(merges)) + + t5_tokenizer = T5Tokenizer(T5_SAMPLE_VOCAB) + t5_tokenizer_path = os.path.join(self.tmpdirname, "t5_tokenizer") + t5_tokenizer.save_pretrained(t5_tokenizer_path) + + @cached_property + def dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer: + return DPRQuestionEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer")) + + @cached_property + def bart_tokenizer(self) -> BartTokenizer: + return BartTokenizer.from_pretrained(os.path.join(self.tmpdirname, "bart_tokenizer")) + + @cached_property + def t5_tokenizer(self) -> BartTokenizer: + return T5Tokenizer.from_pretrained(os.path.join(self.tmpdirname, "t5_tokenizer")) + + def tearDown(self): + shutil.rmtree(self.tmpdirname) + + def get_retriever(self, config): + dataset = Dataset.from_dict( + { + "id": ["0", "1"], + "text": ["foo", "bar"], + "title": ["Foo", "Bar"], + "embeddings": [np.ones(self.retrieval_vector_size), 2 * np.ones(self.retrieval_vector_size)], + } + ) + dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT) + tokenizer = self.bart_tokenizer if config.generator.model_type == "bart" else self.t5_tokenizer + with patch("transformers.retrieval_rag.load_dataset") as mock_load_dataset: + mock_load_dataset.return_value = dataset + retriever = RagRetriever( + config, + question_encoder_tokenizer=self.dpr_tokenizer, + generator_tokenizer=tokenizer, + ) + return retriever + + def check_model_with_retriever( + self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs + ): + self.assertIsNotNone(config.question_encoder) + self.assertIsNotNone(config.generator) + + for model_class in self.all_model_classes: + model = model_class(config, retriever=self.get_retriever(config)).to(torch_device) + model.eval() + + self.assertTrue(model.config.is_encoder_decoder) + + outputs = model( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + ) + + # logits + self.assertEqual( + outputs.logits.shape, + (self.n_docs * decoder_input_ids.shape[0], decoder_input_ids.shape[1], config.generator.vocab_size), + ) + # generator encoder last hidden states + self.assertEqual( + outputs.generator_enc_last_hidden_state.shape, + (self.n_docs * decoder_input_ids.shape[0], self.max_combined_length, config.generator.hidden_size), + ) + # doc scores + self.assertEqual(outputs.doc_scores.shape, (input_ids.shape[0], self.n_docs)) + + def check_model_generate( + self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs + ): + self.assertIsNotNone(config.question_encoder) + self.assertIsNotNone(config.generator) + + for model_class in self.all_model_classes[1:]: + model = model_class(config, retriever=self.get_retriever(config)).to(torch_device) + model.eval() + + self.assertTrue(model.config.is_encoder_decoder) + + outputs = model.generate( + input_ids=input_ids, + num_beams=2, + num_return_sequences=2, + decoder_start_token_id=config.generator.eos_token_id, + ) + + self.assertIsNotNone(outputs) + + def check_model_without_retriever( + self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs + ): + self.assertIsNotNone(config.question_encoder) + self.assertIsNotNone(config.generator) + + retriever = self.get_retriever(config) + + for model_class in self.all_model_classes: + model = model_class(config).to(torch_device) + model.eval() + self.assertTrue(model.config.is_encoder_decoder) + + question_hidden_states = model.question_encoder(input_ids, attention_mask=attention_mask)[0] + + out = retriever( + input_ids, + question_hidden_states.cpu().detach().to(torch.float32).numpy(), + prefix=config.generator.prefix, + return_tensors="pt", + ) + + context_input_ids, context_attention_mask, retrieved_doc_embeds = ( + out["context_input_ids"], + out["context_attention_mask"], + out["retrieved_doc_embeds"], + ) + + # cast + retrieved_doc_embeds = retrieved_doc_embeds.to(question_hidden_states) + context_input_ids = context_input_ids.to(input_ids) + context_attention_mask = context_attention_mask.to(input_ids) + + # compute doc_scores + doc_scores = torch.bmm(question_hidden_states.unsqueeze(1), retrieved_doc_embeds.transpose(1, 2)).squeeze( + 1 + ) + + outputs = model( + context_input_ids=context_input_ids, + context_attention_mask=context_attention_mask, + doc_scores=doc_scores, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + ) + + # logits + self.assertEqual( + outputs.logits.shape, + (self.n_docs * decoder_input_ids.shape[0], decoder_input_ids.shape[1], config.generator.vocab_size), + ) + # generator encoder last hidden states + self.assertEqual( + outputs.generator_enc_last_hidden_state.shape, + (self.n_docs * decoder_input_ids.shape[0], self.max_combined_length, config.generator.hidden_size), + ) + # doc scores + self.assertEqual(outputs.doc_scores.shape, (input_ids.shape[0], self.n_docs)) + + def check_model_with_encoder_outputs( + self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs + ): + self.assertIsNotNone(config.question_encoder) + self.assertIsNotNone(config.generator) + + for model_class in self.all_model_classes: + model = model_class(config, retriever=self.get_retriever(config)).to(torch_device) + model.eval() + + self.assertTrue(model.config.is_encoder_decoder) + + outputs = model( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + ) + + encoder_outputs = BaseModelOutput(outputs.generator_enc_last_hidden_state) + + # run only generator + outputs = model( + encoder_outputs=encoder_outputs, + doc_scores=outputs.doc_scores, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + ) + + # logits + self.assertEqual( + outputs.logits.shape, + (self.n_docs * decoder_input_ids.shape[0], decoder_input_ids.shape[1], config.generator.vocab_size), + ) + # generator encoder last hidden states + self.assertEqual( + outputs.generator_enc_last_hidden_state.shape, + (self.n_docs * decoder_input_ids.shape[0], self.max_combined_length, config.generator.hidden_size), + ) + # doc scores + self.assertEqual(outputs.doc_scores.shape, (input_ids.shape[0], self.n_docs)) + + def test_model_with_retriever(self): + inputs_dict = self.config_and_inputs + self.check_model_with_retriever(**inputs_dict) + + def test_model_without_retriever(self): + inputs_dict = self.config_and_inputs + self.check_model_without_retriever(**inputs_dict) + + def test_model_with_encoder_outputs(self): + inputs_dict = self.config_and_inputs + self.check_model_with_encoder_outputs(**inputs_dict) + + def test_model_generate(self): + inputs_dict = self.config_and_inputs + self.check_model_generate(**inputs_dict) + + +@require_torch +@require_retrieval +class RagDPRBartTest(RagTestMixin, unittest.TestCase): + @cached_property + def config_and_inputs(self): + question_encoder_tester = DPRModelTester(self) + dpr_config_and_inputs = question_encoder_tester.prepare_config_and_inputs() + generator_tester = BartModelTester(self) + bart_config_and_inputs = generator_tester.prepare_config_and_inputs_for_common() + + (question_encoder_config, input_ids, _, input_mask, _, _, _) = dpr_config_and_inputs + (generator_config, bart_inputs_dict) = bart_config_and_inputs + decoder_input_ids, decoder_attention_mask = bart_inputs_dict["input_ids"], bart_inputs_dict["attention_mask"] + + config = RagConfig.from_question_encoder_generator_configs( + question_encoder_config, + generator_config, + n_docs=self.n_docs, + retrieval_vector_size=self.retrieval_vector_size, + max_combined_length=self.max_combined_length, + use_cache=False, + ) + + return { + "config": config, + "input_ids": input_ids, + "attention_mask": input_mask, + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": decoder_attention_mask, + } + + +@require_torch +@require_retrieval +class RagDPRT5Test(RagTestMixin, unittest.TestCase): + @cached_property + def config_and_inputs(self): + question_encoder_tester = DPRModelTester(self) + dpr_config_and_inputs = question_encoder_tester.prepare_config_and_inputs() + generator_tester = T5ModelTester(self, vocab_size=1100, n_positions=30) + t5_config_and_inputs = generator_tester.prepare_config_and_inputs() + + (question_encoder_config, input_ids, _, input_mask, _, _, _) = dpr_config_and_inputs + # import ipdb; ipdb.set_trace() + (generator_config, _, decoder_input_ids, _, decoder_attention_mask, _) = t5_config_and_inputs + config = RagConfig.from_question_encoder_generator_configs( + question_encoder_config, + generator_config, + n_docs=self.n_docs, + retrieval_vector_size=self.retrieval_vector_size, + max_combined_length=self.max_combined_length, + use_cache=False, + ) + + return { + "config": config, + "input_ids": input_ids, + "attention_mask": input_mask, + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": decoder_attention_mask, + } + + +@require_torch +@require_retrieval +class RagModelIntegrationTests(unittest.TestCase): + @cached_property + def sequence_model(self): + return ( + RagSequenceForGeneration.from_pretrained_question_encoder_generator( + "facebook/dpr-question_encoder-single-nq-base", "facebook/bart-large-cnn" + ) + .to(torch_device) + .eval() + ) + + @cached_property + def token_model(self): + return ( + RagTokenForGeneration.from_pretrained_question_encoder_generator( + "facebook/dpr-question_encoder-single-nq-base", "facebook/bart-large-cnn" + ) + .to(torch_device) + .eval() + ) + + def get_rag_config(self): + question_encoder_config = AutoConfig.from_pretrained("facebook/dpr-question_encoder-single-nq-base") + generator_config = AutoConfig.from_pretrained("facebook/bart-large-cnn") + return RagConfig.from_question_encoder_generator_configs( + question_encoder_config, + generator_config, + bos_token_id=0, + decoder_start_token_id=2, + eos_token_id=2, + is_encoder_decoder=True, + pad_token_id=1, + vocab_size=50264, + title_sep=" / ", + doc_sep=" // ", + n_docs=5, + max_combined_length=300, + dataset="wiki_dpr", + dataset_split="train", + index_name="exact", + index_path=None, + use_dummy_dataset=True, + retrieval_vector_size=768, + retrieval_batch_size=8, + ) + + @slow + def test_rag_sequence_inference(self): + rag_config = self.get_rag_config() + rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn") + rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained( + "facebook/dpr-question_encoder-single-nq-base" + ) + rag_retriever = RagRetriever( + rag_config, + question_encoder_tokenizer=rag_question_encoder_tokenizer, + generator_tokenizer=rag_decoder_tokenizer, + ) + + rag_sequence = self.sequence_model + rag_sequence.set_retriever(rag_retriever) + + input_ids = rag_question_encoder_tokenizer( + "who sings does he love me with reba", return_tensors="pt" + ).input_ids + decoder_input_ids = rag_decoder_tokenizer("Linda Davis", return_tensors="pt").input_ids + + input_ids = input_ids.to(torch_device) + decoder_input_ids = decoder_input_ids.to(torch_device) + + with torch.no_grad(): + output = rag_sequence( + input_ids, + labels=decoder_input_ids, + ) + + expected_shape = torch.Size([5, 5, 50264]) + self.assertEqual(output.logits.shape, expected_shape) + + expected_doc_scores = torch.tensor([[75.0286, 74.4998, 74.0804, 74.0306, 73.9504]]).to(torch_device) + _assert_tensors_equal(expected_doc_scores, output.doc_scores, atol=TOLERANCE) + + expected_loss = torch.tensor([38.7446]).to(torch_device) + _assert_tensors_equal(expected_loss, output.loss, atol=TOLERANCE) + + @slow + def test_rag_token_inference(self): + rag_config = self.get_rag_config() + rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn") + rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained( + "facebook/dpr-question_encoder-single-nq-base" + ) + rag_retriever = RagRetriever( + rag_config, + question_encoder_tokenizer=rag_question_encoder_tokenizer, + generator_tokenizer=rag_decoder_tokenizer, + ) + + rag_token = self.token_model + rag_token.set_retriever(rag_retriever) + + input_ids = rag_question_encoder_tokenizer( + "who sings does he love me with reba", return_tensors="pt" + ).input_ids + decoder_input_ids = rag_decoder_tokenizer("Linda Davis", return_tensors="pt").input_ids + + input_ids = input_ids.to(torch_device) + decoder_input_ids = decoder_input_ids.to(torch_device) + + with torch.no_grad(): + output = rag_token( + input_ids, + labels=decoder_input_ids, + ) + + expected_shape = torch.Size([5, 5, 50264]) + self.assertEqual(output.logits.shape, expected_shape) + + expected_doc_scores = torch.tensor([[75.0286, 74.4998, 74.0804, 74.0306, 73.9504]]).to(torch_device) + _assert_tensors_equal(expected_doc_scores, output.doc_scores, atol=TOLERANCE) + + expected_loss = torch.tensor([38.7045]).to(torch_device) + _assert_tensors_equal(expected_loss, output.loss, atol=TOLERANCE) + + @slow + def test_rag_token_generate_beam(self): + rag_config = self.get_rag_config() + rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn") + rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained( + "facebook/dpr-question_encoder-single-nq-base" + ) + rag_retriever = RagRetriever( + rag_config, + question_encoder_tokenizer=rag_question_encoder_tokenizer, + generator_tokenizer=rag_decoder_tokenizer, + ) + + rag_token = self.token_model + rag_token.set_retriever(rag_retriever) + + input_ids = rag_question_encoder_tokenizer( + "who sings does he love me with reba", return_tensors="pt" + ).input_ids + + input_ids = input_ids.to(torch_device) + + output_ids = rag_token.generate( + input_ids, + decoder_start_token_id=rag_token.generator.config.decoder_start_token_id, + num_beams=2, + num_return_sequences=2, + ) + # sequence generate test + output_text_1 = rag_decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True) + output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True) + + # Expected outputs as given by model at integration time. + EXPECTED_OUTPUT_TEXT_1 = "The songwriting credits are credited to ABBA" + EXPECTED_OUTPUT_TEXT_2 = 'The songwriting credits are credited to "B' + + self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1) + self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2) + + @slow + def test_rag_token_generate_batch(self): + rag_config = self.get_rag_config() + rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn") + rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained( + "facebook/dpr-question_encoder-single-nq-base" + ) + rag_retriever = RagRetriever( + rag_config, + question_encoder_tokenizer=rag_question_encoder_tokenizer, + generator_tokenizer=rag_decoder_tokenizer, + ) + + rag_token = self.token_model + rag_token.set_retriever(rag_retriever) + + questions = [ + "who sings does he love me with reba", + "how many pages is invisible man by ralph ellison", + ] + input_ids = rag_question_encoder_tokenizer.batch_encode_plus( + questions, + return_tensors="pt", + padding=True, + truncation=True, + ).input_ids + + input_ids = input_ids.to(torch_device) + + output_ids = rag_token.generate( + input_ids, + decoder_start_token_id=rag_token.generator.config.decoder_start_token_id, + num_beams=4, + num_return_sequences=1, + max_length=10, + ) + + # sequence generate test + output_text_1 = rag_decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True) + output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True) + + # Expected outputs as given by model at integration time. + EXPECTED_OUTPUT_TEXT_1 = '"People Need Love" is the' + EXPECTED_OUTPUT_TEXT_2 = '"How many pages is invisible man' + + self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1) + self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2) + + @slow + def test_rag_sequence_generate_batch(self): + # IMPORTAN: This test fails on GPU, but is fine on CPU -> beam search is very sensible + rag_config = self.get_rag_config() + rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn") + rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained( + "facebook/dpr-question_encoder-single-nq-base" + ) + rag_retriever = RagRetriever( + rag_config, + question_encoder_tokenizer=rag_question_encoder_tokenizer, + generator_tokenizer=rag_decoder_tokenizer, + ) + + rag_sequence = self.sequence_model + rag_sequence.set_retriever(rag_retriever) + + questions = [ + "who sings does he love me with reba", + "how many pages is invisible man by ralph ellison", + ] + input_ids = rag_question_encoder_tokenizer.batch_encode_plus( + questions, + return_tensors="pt", + padding=True, + truncation=True, + ).input_ids + + input_ids = input_ids.to(torch_device) + + output_ids = rag_sequence.generate( + input_ids, + decoder_start_token_id=rag_sequence.generator.config.decoder_start_token_id, + num_beams=4, + num_return_sequences=1, + max_length=10, + ) + + # sequence generate test + output_text_1 = rag_decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True) + output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True) + + # Expected outputs as given by model at integration time. + EXPECTED_OUTPUT_TEXT_1 = '"I Know Him So Well"' + EXPECTED_OUTPUT_TEXT_2 = '"Howl" chronicles the' + + self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1) + self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2) + + @slow + def test_rag_sequence_generate_beam(self): + rag_config = self.get_rag_config() + rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn") + rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained( + "facebook/dpr-question_encoder-single-nq-base" + ) + rag_retriever = RagRetriever( + rag_config, + question_encoder_tokenizer=rag_question_encoder_tokenizer, + generator_tokenizer=rag_decoder_tokenizer, + ) + + rag_token = self.sequence_model + rag_token.set_retriever(rag_retriever) + + input_ids = rag_question_encoder_tokenizer( + "who sings does he love me with reba", return_tensors="pt" + ).input_ids + + input_ids = input_ids.to(torch_device) + + output_ids = rag_token.generate( + input_ids, + decoder_start_token_id=rag_token.generator.config.decoder_start_token_id, + num_beams=2, + num_return_sequences=2, + ) + # sequence generate test + output_text_1 = rag_decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True) + output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True) + + # Expected outputs as given by model at integration time. + EXPECTED_OUTPUT_TEXT_1 = """ ABBA / small label like Playboy Records did not have the distribution resources to meet the demand for the single from retailers and radio programmers. The foursome decided to record their first album together in late 1972, and sessions began on 26 September 1972. The women shared lead vocals on "Nina, Pretty Ballerina" that day.""" + EXPECTED_OUTPUT_TEXT_2 = """ ABBA / small label like Playboy Records did not have the distribution resources to meet the demand for the single from retailers and radio programmers. The foursome decided to record their first album together in late 1972, and sessions began on 26 September 1972. The women shared lead vocals on "Nina, Pretty Ballerina" (a top ten hit in Austria)""" + + self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1) + self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2) + + +@require_torch +@require_retrieval +class RagModelSaveLoadTests(unittest.TestCase): + def get_rag_config(self): + question_encoder_config = AutoConfig.from_pretrained("facebook/dpr-question_encoder-single-nq-base") + generator_config = AutoConfig.from_pretrained("facebook/bart-large-cnn") + return RagConfig.from_question_encoder_generator_configs( + question_encoder_config, + generator_config, + bos_token_id=0, + decoder_start_token_id=2, + eos_token_id=2, + is_encoder_decoder=True, + pad_token_id=1, + vocab_size=50264, + title_sep=" / ", + doc_sep=" // ", + n_docs=5, + max_combined_length=300, + dataset="wiki_dpr", + dataset_split="train", + index_name="exact", + index_path=None, + use_dummy_dataset=True, + retrieval_vector_size=768, + retrieval_batch_size=8, + ) + + @slow + def test_rag_sequence_from_pretrained(self): + rag_config = self.get_rag_config() + rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn") + rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained( + "facebook/dpr-question_encoder-single-nq-base" + ) + rag_retriever = RagRetriever( + rag_config, + question_encoder_tokenizer=rag_question_encoder_tokenizer, + generator_tokenizer=rag_decoder_tokenizer, + ) + + input_ids = rag_question_encoder_tokenizer( + "who sings does he love me with reba", return_tensors="pt" + ).input_ids + decoder_input_ids = rag_decoder_tokenizer("Linda Davis", return_tensors="pt").input_ids + + input_ids = input_ids.to(torch_device) + decoder_input_ids = decoder_input_ids.to(torch_device) + + with tempfile.TemporaryDirectory() as tmp_dirname: + rag_sequence = RagSequenceForGeneration.from_pretrained_question_encoder_generator( + "facebook/dpr-question_encoder-single-nq-base", + "facebook/bart-large-cnn", + retriever=rag_retriever, + config=rag_config, + ).to(torch_device) + # check that the from pretrained methods work + rag_sequence.save_pretrained(tmp_dirname) + rag_sequence.from_pretrained(tmp_dirname, retriever=rag_retriever) + rag_sequence.to(torch_device) + + with torch.no_grad(): + output = rag_sequence( + input_ids, + labels=decoder_input_ids, + ) + + loss_pretrained = output.loss + del rag_sequence + + question_encoder = AutoModel.from_pretrained("facebook/dpr-question_encoder-single-nq-base") + generator = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn") + rag_sequence = RagSequenceForGeneration( + config=rag_config, question_encoder=question_encoder, generator=generator, retriever=rag_retriever + ) + rag_sequence.to(torch_device) + + with torch.no_grad(): + output = rag_sequence( + input_ids, + labels=decoder_input_ids, + ) + + loss_init = output.loss + + self.assertAlmostEqual(loss_pretrained.item(), loss_init.item(), places=4) + + @slow + def test_rag_token_from_pretrained(self): + rag_config = self.get_rag_config() + rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn") + rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained( + "facebook/dpr-question_encoder-single-nq-base" + ) + rag_retriever = RagRetriever( + rag_config, + question_encoder_tokenizer=rag_question_encoder_tokenizer, + generator_tokenizer=rag_decoder_tokenizer, + ) + + input_ids = rag_question_encoder_tokenizer( + "who sings does he love me with reba", return_tensors="pt" + ).input_ids + decoder_input_ids = rag_decoder_tokenizer("Linda Davis", return_tensors="pt").input_ids + + input_ids = input_ids.to(torch_device) + decoder_input_ids = decoder_input_ids.to(torch_device) + + with tempfile.TemporaryDirectory() as tmp_dirname: + rag_token = RagTokenForGeneration.from_pretrained_question_encoder_generator( + "facebook/dpr-question_encoder-single-nq-base", + "facebook/bart-large-cnn", + retriever=rag_retriever, + config=rag_config, + ).to(torch_device) + # check that the from pretrained methods work + rag_token.save_pretrained(tmp_dirname) + rag_token.from_pretrained(tmp_dirname, retriever=rag_retriever) + rag_token.to(torch_device) + + with torch.no_grad(): + output = rag_token( + input_ids, + labels=decoder_input_ids, + ) + + loss_pretrained = output.loss + del rag_token + + question_encoder = AutoModel.from_pretrained("facebook/dpr-question_encoder-single-nq-base") + generator = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn") + rag_token = RagTokenForGeneration( + config=rag_config, question_encoder=question_encoder, generator=generator, retriever=rag_retriever + ) + rag_token.to(torch_device) + + with torch.no_grad(): + output = rag_token( + input_ids, + labels=decoder_input_ids, + ) + + loss_init = output.loss + + self.assertAlmostEqual(loss_pretrained.item(), loss_init.item(), places=4) diff --git a/tests/test_modeling_t5.py b/tests/test_modeling_t5.py index 4c411d844..c5e3ec9d1 100644 --- a/tests/test_modeling_t5.py +++ b/tests/test_modeling_t5.py @@ -35,28 +35,53 @@ if is_torch_available(): class T5ModelTester: - def __init__(self, parent): + def __init__( + self, + parent, + vocab_size=99, + n_positions=14, + batch_size=13, + encoder_seq_length=7, + decoder_seq_length=9, + # For common tests + seq_length=7, + is_training=True, + use_attention_mask=True, + use_labels=True, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + d_ff=37, + relative_attention_num_buckets=8, + dropout_rate=0.1, + initializer_factor=0.002, + eos_token_id=1, + pad_token_id=0, + decoder_start_token_id=0, + scope=None, + ): + self.parent = parent - self.batch_size = 13 - self.encoder_seq_length = 7 - self.decoder_seq_length = 9 + self.batch_size = batch_size + self.encoder_seq_length = encoder_seq_length + self.decoder_seq_length = decoder_seq_length # For common tests self.seq_length = self.decoder_seq_length - self.is_training = True - self.use_attention_mask = True - self.use_labels = True - self.vocab_size = 99 - self.n_positions = 14 - self.hidden_size = 32 - self.num_hidden_layers = 5 - self.num_attention_heads = 4 - self.d_ff = 37 - self.relative_attention_num_buckets = 8 - self.dropout_rate = 0.1 - self.initializer_factor = 0.002 - self.eos_token_id = 1 - self.pad_token_id = 0 - self.decoder_start_token_id = 0 + self.is_training = is_training + self.use_attention_mask = use_attention_mask + self.use_labels = use_labels + self.vocab_size = vocab_size + self.n_positions = n_positions + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.d_ff = d_ff + self.relative_attention_num_buckets = relative_attention_num_buckets + self.dropout_rate = dropout_rate + self.initializer_factor = initializer_factor + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.decoder_start_token_id = decoder_start_token_id self.scope = None def prepare_config_and_inputs(self): diff --git a/tests/test_retrieval_rag.py b/tests/test_retrieval_rag.py new file mode 100644 index 000000000..120a7a8c7 --- /dev/null +++ b/tests/test_retrieval_rag.py @@ -0,0 +1,223 @@ +import json +import os +import pickle +import shutil +import tempfile +from unittest import TestCase +from unittest.mock import patch + +import numpy as np +from datasets import Dataset + +import faiss +from transformers.configuration_bart import BartConfig +from transformers.configuration_dpr import DPRConfig +from transformers.configuration_rag import RagConfig +from transformers.retrieval_rag import RagRetriever +from transformers.testing_utils import require_datasets, require_faiss, require_torch +from transformers.tokenization_bart import BartTokenizer +from transformers.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES +from transformers.tokenization_dpr import DPRQuestionEncoderTokenizer +from transformers.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES + + +@require_faiss +@require_datasets +class RagRetrieverTest(TestCase): + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + self.retrieval_vector_size = 8 + + # DPR tok + vocab_tokens = [ + "[UNK]", + "[CLS]", + "[SEP]", + "[PAD]", + "[MASK]", + "want", + "##want", + "##ed", + "wa", + "un", + "runn", + "##ing", + ",", + "low", + "lowest", + ] + dpr_tokenizer_path = os.path.join(self.tmpdirname, "dpr_tokenizer") + os.makedirs(dpr_tokenizer_path, exist_ok=True) + self.vocab_file = os.path.join(dpr_tokenizer_path, DPR_VOCAB_FILES_NAMES["vocab_file"]) + with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer: + vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) + + # BART tok + vocab = [ + "l", + "o", + "w", + "e", + "r", + "s", + "t", + "i", + "d", + "n", + "\u0120", + "\u0120l", + "\u0120n", + "\u0120lo", + "\u0120low", + "er", + "\u0120lowest", + "\u0120newer", + "\u0120wider", + "", + ] + vocab_tokens = dict(zip(vocab, range(len(vocab)))) + merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""] + self.special_tokens_map = {"unk_token": ""} + + bart_tokenizer_path = os.path.join(self.tmpdirname, "bart_tokenizer") + os.makedirs(bart_tokenizer_path, exist_ok=True) + self.vocab_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["vocab_file"]) + self.merges_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["merges_file"]) + with open(self.vocab_file, "w", encoding="utf-8") as fp: + fp.write(json.dumps(vocab_tokens) + "\n") + with open(self.merges_file, "w", encoding="utf-8") as fp: + fp.write("\n".join(merges)) + + def get_dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer: + return DPRQuestionEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer")) + + def get_bart_tokenizer(self) -> BartTokenizer: + return BartTokenizer.from_pretrained(os.path.join(self.tmpdirname, "bart_tokenizer")) + + def tearDown(self): + shutil.rmtree(self.tmpdirname) + + def get_dummy_hf_index_retriever(self): + dataset = Dataset.from_dict( + { + "id": ["0", "1"], + "text": ["foo", "bar"], + "title": ["Foo", "Bar"], + "embeddings": [np.ones(self.retrieval_vector_size), 2 * np.ones(self.retrieval_vector_size)], + } + ) + dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT) + config = RagConfig( + retrieval_vector_size=self.retrieval_vector_size, + question_encoder=DPRConfig().to_dict(), + generator=BartConfig().to_dict(), + ) + with patch("transformers.retrieval_rag.load_dataset") as mock_load_dataset: + mock_load_dataset.return_value = dataset + retriever = RagRetriever( + config, + question_encoder_tokenizer=self.get_dpr_tokenizer(), + generator_tokenizer=self.get_bart_tokenizer(), + ) + return retriever + + def get_dummy_legacy_index_retriever(self): + dataset = Dataset.from_dict( + { + "id": ["0", "1"], + "text": ["foo", "bar"], + "title": ["Foo", "Bar"], + "embeddings": [np.ones(self.retrieval_vector_size + 1), 2 * np.ones(self.retrieval_vector_size + 1)], + } + ) + dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT) + + index_file_name = os.path.join(self.tmpdirname, "hf_bert_base.hnswSQ8_correct_phi_128.c_index") + dataset.save_faiss_index("embeddings", index_file_name + ".index.dpr") + pickle.dump(dataset["id"], open(index_file_name + ".index_meta.dpr", "wb")) + + passages_file_name = os.path.join(self.tmpdirname, "psgs_w100.tsv.pkl") + passages = {sample["id"]: [sample["text"], sample["title"]] for sample in dataset} + pickle.dump(passages, open(passages_file_name, "wb")) + + config = RagConfig( + retrieval_vector_size=self.retrieval_vector_size, + question_encoder=DPRConfig().to_dict(), + generator=BartConfig().to_dict(), + index_name="legacy", + index_path=self.tmpdirname, + passages_path=self.tmpdirname, + ) + retriever = RagRetriever( + config, question_encoder_tokenizer=self.get_dpr_tokenizer(), generator_tokenizer=self.get_bart_tokenizer() + ) + return retriever + + def test_hf_index_retriever_retrieve(self): + n_docs = 1 + retriever = self.get_dummy_hf_index_retriever() + hidden_states = np.array( + [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32 + ) + retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs) + self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size)) + self.assertEqual(len(doc_dicts), 2) + self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"]) + self.assertEqual(len(doc_dicts[0]["id"]), n_docs) + self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc + self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc + self.assertListEqual(list(doc_ids), [1, 0]) + + def test_legacy_index_retriever_retrieve(self): + n_docs = 1 + retriever = self.get_dummy_legacy_index_retriever() + hidden_states = np.array( + [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32 + ) + retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs) + self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size)) + self.assertEqual(len(doc_dicts), 2) + self.assertEqual(sorted(doc_dicts[0]), ["text", "title"]) + self.assertEqual(len(doc_dicts[0]["text"]), n_docs) + self.assertEqual(doc_dicts[0]["text"][0], "bar") # max inner product is reached with second doc + self.assertEqual(doc_dicts[1]["text"][0], "foo") # max inner product is reached with first doc + self.assertListEqual(list(doc_ids), [1, 0]) + + @require_torch + def test_hf_index_retriever_call(self): + import torch + + n_docs = 1 + retriever = self.get_dummy_hf_index_retriever() + question_input_ids = [[5, 7], [10, 11]] + hidden_states = np.array( + [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32 + ) + out = retriever(question_input_ids, hidden_states, prefix=retriever.config.generator.prefix, n_docs=n_docs) + context_input_ids, context_attention_mask, retrieved_doc_embeds = ( + out["context_input_ids"], + out["context_attention_mask"], + out["retrieved_doc_embeds"], + ) + self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size)) + self.assertIsInstance(context_input_ids, list) + self.assertIsInstance(context_attention_mask, list) + self.assertIsInstance(retrieved_doc_embeds, np.ndarray) + + out = retriever( + question_input_ids, + hidden_states, + prefix=retriever.config.generator.prefix, + n_docs=n_docs, + return_tensors="pt", + ) + context_input_ids, context_attention_mask, retrieved_doc_embeds, doc_ids = ( # noqa: F841 + out["context_input_ids"], + out["context_attention_mask"], + out["retrieved_doc_embeds"], + out["doc_ids"], + ) + self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size)) + self.assertIsInstance(context_input_ids, torch.Tensor) + self.assertIsInstance(context_attention_mask, torch.Tensor) + self.assertIsInstance(retrieved_doc_embeds, torch.Tensor) diff --git a/tests/test_tokenization_rag.py b/tests/test_tokenization_rag.py new file mode 100644 index 000000000..ab18af490 --- /dev/null +++ b/tests/test_tokenization_rag.py @@ -0,0 +1,110 @@ +import json +import os +import shutil +import tempfile +from unittest import TestCase + +from transformers.configuration_bart import BartConfig +from transformers.configuration_dpr import DPRConfig +from transformers.file_utils import is_datasets_available, is_faiss_available, is_torch_available +from transformers.testing_utils import require_datasets, require_faiss, require_torch +from transformers.tokenization_bart import BartTokenizer +from transformers.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES +from transformers.tokenization_dpr import DPRQuestionEncoderTokenizer +from transformers.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES + + +if is_torch_available() and is_datasets_available() and is_faiss_available(): + from transformers.configuration_rag import RagConfig + from transformers.tokenization_rag import RagTokenizer + + +@require_faiss +@require_datasets +@require_torch +class RagTokenizerTest(TestCase): + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + self.retrieval_vector_size = 8 + + # DPR tok + vocab_tokens = [ + "[UNK]", + "[CLS]", + "[SEP]", + "[PAD]", + "[MASK]", + "want", + "##want", + "##ed", + "wa", + "un", + "runn", + "##ing", + ",", + "low", + "lowest", + ] + dpr_tokenizer_path = os.path.join(self.tmpdirname, "dpr_tokenizer") + os.makedirs(dpr_tokenizer_path, exist_ok=True) + self.vocab_file = os.path.join(dpr_tokenizer_path, DPR_VOCAB_FILES_NAMES["vocab_file"]) + with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer: + vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) + + # BART tok + vocab = [ + "l", + "o", + "w", + "e", + "r", + "s", + "t", + "i", + "d", + "n", + "\u0120", + "\u0120l", + "\u0120n", + "\u0120lo", + "\u0120low", + "er", + "\u0120lowest", + "\u0120newer", + "\u0120wider", + "", + ] + vocab_tokens = dict(zip(vocab, range(len(vocab)))) + merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""] + self.special_tokens_map = {"unk_token": ""} + + bart_tokenizer_path = os.path.join(self.tmpdirname, "bart_tokenizer") + os.makedirs(bart_tokenizer_path, exist_ok=True) + self.vocab_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["vocab_file"]) + self.merges_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["merges_file"]) + with open(self.vocab_file, "w", encoding="utf-8") as fp: + fp.write(json.dumps(vocab_tokens) + "\n") + with open(self.merges_file, "w", encoding="utf-8") as fp: + fp.write("\n".join(merges)) + + def get_dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer: + return DPRQuestionEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer")) + + def get_bart_tokenizer(self) -> BartTokenizer: + return BartTokenizer.from_pretrained(os.path.join(self.tmpdirname, "bart_tokenizer")) + + def tearDown(self): + shutil.rmtree(self.tmpdirname) + + def test_save_load_pretrained_with_saved_config(self): + + save_dir = os.path.join(self.tmpdirname, "rag_tokenizer") + rag_config = RagConfig(question_encoder=DPRConfig().to_dict(), generator=BartConfig().to_dict()) + rag_tokenizer = RagTokenizer(question_encoder=self.get_dpr_tokenizer(), generator=self.get_bart_tokenizer()) + rag_config.save_pretrained(save_dir) + rag_tokenizer.save_pretrained(save_dir) + new_rag_tokenizer = RagTokenizer.from_pretrained(save_dir, config=rag_config) + self.assertIsInstance(new_rag_tokenizer.question_encoder, DPRQuestionEncoderTokenizer) + self.assertEqual(new_rag_tokenizer.question_encoder.vocab, rag_tokenizer.question_encoder.vocab) + self.assertIsInstance(new_rag_tokenizer.generator, BartTokenizer) + self.assertEqual(new_rag_tokenizer.generator.encoder, rag_tokenizer.generator.encoder)