RAG (#6813)
* added rag WIP * path fix * Formatting / renaming prior to actual work * added rag WIP * path fix * Formatting / renaming prior to actual work * added rag WIP * path fix * Formatting / renaming prior to actual work * added rag WIP * Formatting / renaming prior to actual work * First commit * improve comments * Retrieval evaluation scripts * refactor to include modeling outputs + MPI retriever * Fix rag-token model + refactor * Various fixes + finetuning logic * use_bos fix * Retrieval refactor * Finetuning refactoring and cleanup * Add documentation and cleanup * Remove set_up_rag_env.sh file * Fix retrieval wit HF index * Fix import errors * Fix quality errors * Refactor as per suggestions in https://github.com/huggingface/transformers/pull/6813#issuecomment-687208867 * fix quality * Fix RAG Sequence generation * minor cleanup plus initial tests * fix test * fix tests 2 * Comments fix * post-merge fixes * Improve readme + post-rebase refactor * Extra dependencied for tests * Fix tests * Fix tests 2 * Refactor test requirements * Fix tests 3 * Post-rebase refactor * rename nlp->datasets * RAG integration tests * add tokenizer to slow integration test and allow retriever to run on cpu * add tests; fix position ids warning * change structure * change structure * add from encoder generator * save working solution * make all integration tests pass * add RagTokenizer.save/from_pretrained and RagRetriever.save/from_pretrained * don't save paths * delete unnecessary imports * pass config to AutoTokenizer.from_pretrained for Rag tokenizers * init wiki_dpr only once * hardcode legacy index and passages paths (todo: add the right urls) * finalize config * finalize retriver api and config api * LegacyIndex index download refactor * add dpr to autotokenizer * make from pretrained more flexible * fix ragfortokengeneration * small name changes in tokenizer * add labels to models * change default index name * add retrieval tests * finish token generate * align test with previous version and make all tests pass * add tests * finalize tests * implement thoms suggestions * add first version of test * make first tests work * make retriever platform agnostic * naming * style * add legacy index URL * docstrings + simple retrieval test for distributed * clean model api * add doc_ids to retriever's outputs * fix retrieval tests * finish model outputs * finalize model api * fix generate problem for rag * fix generate for other modles * fix some tests * save intermediate * set generate to default * big refactor generate * delete rag_api * correct pip faiss install * fix auto tokenization test * fix faiss install * fix test * move the distributed logic to examples * model page * docs * finish tests * fix dependencies * fix import in __init__ * Refactor eval_rag and finetune scripts * start docstring * add psutil to test * fix tf test * move require torch to top * fix retrieval test * align naming * finish automodel * fix repo consistency * test ragtokenizer save/load * add rag model output docs * fix ragtokenizer save/load from pretrained * fix tokenizer dir * remove torch in retrieval * fix docs * fixe finetune scripts * finish model docs * finish docs * remove auto model for now * add require torch * remove solved todos * integrate sylvains suggestions * sams comments * correct mistake on purpose * improve README * Add generation test cases * fix rag token * clean token generate * fix test * add note to test * fix attention mask * add t5 test for rag * Fix handling prefix in finetune.py * don't overwrite 
index_name

Co-authored-by: Patrick Lewis <plewis@fb.com>
Co-authored-by: Aleksandra Piktus <piktus@devfair0141.h2.fair>
Co-authored-by: Aleksandra Piktus <piktus@learnfair5102.h2.fair>
Co-authored-by: Aleksandra Piktus <piktus@learnfair5067.h2.fair>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Quentin Lhoest <lhoest.q@gmail.com>
This commit is contained in:
Parent
1ee2194fb6
Commit
c754c41c61
@@ -11,6 +11,7 @@ __pycache__/
# tests and logs
tests/fixtures
logs/
lightning_logs/

# Distribution / packaging
.Python
@@ -139,6 +140,7 @@ runs
/wandb
/examples/runs
/examples/**/*.args
/examples/rag/sweep

# data
/data
@@ -231,6 +231,7 @@ conversion utilities for the following models:
    model_doc/lxmert
    model_doc/bertgeneration
    model_doc/layoutlm
    model_doc/rag
    internal/modeling_utils
    internal/tokenization_utils
    internal/pipelines_utils
@@ -0,0 +1,88 @@
RAG
----------------------------------------------------

Overview
~~~~~~~~~~~~~~~~~~~~~

Retrieval-augmented generation ("RAG") models combine the powers of pretrained dense retrieval (DPR) and Seq2Seq models.
RAG models retrieve docs, pass them to a seq2seq model, then marginalize to generate outputs.
The retriever and seq2seq modules are initialized from pretrained models, and fine-tuned jointly, allowing both retrieval and generation to adapt to downstream tasks.

It is based on the paper `Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks <https://arxiv.org/abs/2005.11401>`__ by Patrick Lewis, Ethan Perez, Aleksandra Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, Sebastian Riedel, Douwe Kiela.

The abstract from the paper is the following:

*Large pre-trained language models have been shown to store factual knowledge
in their parameters, and achieve state-of-the-art results when fine-tuned on
downstream NLP tasks. However, their ability to access and precisely manipulate
knowledge is still limited, and hence on knowledge-intensive tasks, their
performance lags behind task-specific architectures. Additionally, providing
provenance for their decisions and updating their world knowledge remain open
research problems. Pre-trained models with a differentiable access mechanism to
explicit nonparametric memory can overcome this issue, but have so far been only
investigated for extractive downstream tasks. We explore a general-purpose
fine-tuning recipe for retrieval-augmented generation (RAG) — models which combine
pre-trained parametric and non-parametric memory for language generation. We
introduce RAG models where the parametric memory is a pre-trained seq2seq model and
the non-parametric memory is a dense vector index of Wikipedia, accessed with
a pre-trained neural retriever. We compare two RAG formulations, one which
conditions on the same retrieved passages across the whole generated sequence, the
other can use different passages per token. We fine-tune and evaluate our models
on a wide range of knowledge-intensive NLP tasks and set the state-of-the-art
on three open domain QA tasks, outperforming parametric seq2seq models and
task-specific retrieve-and-extract architectures. For language generation tasks, we
find that RAG models generate more specific, diverse and factual language than a
state-of-the-art parametric-only seq2seq baseline.*

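A minimal usage sketch (it assumes the pretrained ``facebook/rag-sequence-nq`` checkpoint and the dummy variant of the ``wiki_dpr`` index can be downloaded; any compatible checkpoint and index would be used the same way):

.. code-block:: python

    from transformers import RagRetriever, RagSequenceForGeneration, RagTokenizer

    # tokenizer, retriever (backed here by the small dummy index) and model
    tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
    retriever = RagRetriever.from_pretrained(
        "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
    )
    model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)

    # encode a question, retrieve supporting documents and marginalize over them while generating
    inputs = tokenizer.prepare_seq2seq_batch("who holds the record in 100m freestyle", return_tensors="pt")
    generated = model.generate(input_ids=inputs["input_ids"])
    print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])
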
RagConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.RagConfig
    :members:


RagTokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.RagTokenizer
    :members:


Rag specific outputs
~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.modeling_rag.RetrievAugLMMarginOutput
    :members:

.. autoclass:: transformers.modeling_rag.RetrievAugLMOutput
    :members:


RagRetriever
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.RagRetriever
    :members:


RagModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.RagModel
    :members: forward


RagSequenceForGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.RagSequenceForGeneration
    :members: forward, generate


RagTokenForGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.RagTokenForGeneration
    :members: forward, generate
@@ -654,7 +654,7 @@ DPR
    <a href="https://huggingface.co/models?filter=dpr">
        <img alt="Models" src="https://img.shields.io/badge/All_model_pages-dpr-blueviolet">
    </a>
    <a href="model_doc/ctrl.dpr">
    <a href="model_doc/dpr.html">
        <img alt="Doc" src="https://img.shields.io/badge/Model_documentation-dpr-blueviolet">
    </a>

@@ -672,6 +672,27 @@ DPR consists in three models:

DPR's pipeline (not implemented yet) uses a retrieval step to find the top k contexts given a certain question, and then it calls the reader with the question and the retrieved documents to get the answer.

RAG
----------------------------------------------

.. raw:: html

    <a href="https://huggingface.co/models?filter=rag">
        <img alt="Models" src="https://img.shields.io/badge/All_model_pages-rag-blueviolet">
    </a>
    <a href="model_doc/rag.html">
        <img alt="Doc" src="https://img.shields.io/badge/Model_documentation-rag-blueviolet">
    </a>

`Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks <https://arxiv.org/abs/2005.11401>`_,
Patrick Lewis, Ethan Perez, Aleksandra Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, Sebastian Riedel, Douwe Kiela

Retrieval-augmented generation ("RAG") models combine the powers of pretrained dense retrieval (DPR) and Seq2Seq models.
RAG models retrieve docs, pass them to a seq2seq model, then marginalize to generate outputs.
The retriever and seq2seq modules are initialized from pretrained models, and fine-tuned jointly, allowing both retrieval and generation to adapt to downstream tasks.

The two models RAG-Token and RAG-Sequence are available for generation.

More technical aspects
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

@@ -366,6 +366,8 @@ def generic_train(
    if args.gpus > 1:
        train_params["distributed_backend"] = "ddp"

    train_params["accumulate_grad_batches"] = args.accumulate_grad_batches

    trainer = pl.Trainer.from_argparse_args(
        args,
        weights_summary=None,
@@ -1,10 +1,10 @@
import datasets
import faiss
import numpy as np
import streamlit as st
import torch
from elasticsearch import Elasticsearch

import faiss
import transformers
from eli5_utils import (
    embed_questions_for_retrieval,
@@ -5,7 +5,6 @@ from random import choice, randint
from time import time

import datasets  # noqa: F401
import faiss  # noqa: F401
import numpy as np
import pandas as pd
import torch
@@ -15,6 +14,7 @@ from elasticsearch.helpers import bulk, streaming_bulk  # noqa: F401
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from tqdm import tqdm

import faiss  # noqa: F401
from transformers import AdamW, AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer, get_linear_schedule_with_warmup

@@ -0,0 +1,88 @@
# Intro
RAG is a seq2seq model which encapsulates two core components: a question encoder and a generator.
During a forward pass, we encode the input with the question encoder and pass it
to the retriever to extract relevant context documents. The documents are then prepended to the input.
Such contextualized inputs are passed to the generator.

The question encoder can be any `autoencoding` model, preferably :obj:`~transformers.DPRQuestionEncoder`, and the generator can be any `seq2seq` model, preferably :obj:`~transformers.BartForConditionalGeneration`.

The model can be initialized with a :obj:`~transformers.RagRetriever` for end-to-end generation or used in combination with the outputs of a retriever in multiple steps - see examples for more details.
The model is compatible with any `autoencoding` model as the ``question_encoder`` and any `seq2seq` model with a language model head as the ``generator``.
The model has been tested with :class:`~transformers.DPRQuestionEncoder` as the ``question_encoder`` and :class:`~transformers.BartForConditionalGeneration` or :class:`~transformers.T5ForConditionalGeneration` as the ``generator``.

RAG models were released with the paper `Retrieval-Augmented Generation for
Knowledge-Intensive NLP Tasks <https://arxiv.org/abs/2005.11401>`_ by Patrick Lewis, Ethan Perez, Aleksandra Piktus et al.

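A RAG model can also be assembled directly from a pretrained question encoder and a pretrained generator. The sketch below is illustrative only: it assumes the `facebook/dpr-question_encoder-single-nq-base` and `facebook/bart-large` checkpoints and an already constructed `RagRetriever` (here the dummy `wiki_dpr` index), and uses the `from_pretrained_question_encoder_generator` helper added by this PR.

```
from transformers import RagRetriever, RagSequenceForGeneration

# retriever backed by the dummy wiki_dpr index, for illustration only
retriever = RagRetriever.from_pretrained(
    "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
)

# combine a DPR question encoder and a BART generator into a single RAG model
model = RagSequenceForGeneration.from_pretrained_question_encoder_generator(
    "facebook/dpr-question_encoder-single-nq-base", "facebook/bart-large", retriever=retriever
)
model.save_pretrained("rag-initialized-from-dpr-and-bart")
```
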
# Finetuning
Our finetuning logic is based on scripts from [`examples/seq2seq`](https://github.com/huggingface/transformers/tree/master/examples/seq2seq).
Follow instructions there regarding data preprocessing. A sample finetuning command:

```
python examples/rag/finetune.py \
    --data_dir $DATA_DIR \
    --output_dir $OUTPUT_DIR \
    --model_name_or_path $MODEL_NAME_OR_PATH \
    --model_type rag_sequence \
    --fp16 \
    --gpus 8
```

# Evaluation
Apart from the parameters specifying the model to evaluate and some extra parameters, the evaluation script expects paths to two files:
- `evaluation_set` - a path to a file specifying the evaluation dataset, a single datapoint per line, e.g.
```who is the owner of reading football club```
- `gold_data_path` - a path to a file containing ground truth answers for datapoints from the `evaluation_set`.

We expect the following formats of the gold data file:

- for e2e evaluation, we support two formats of the gold file:
    - `qa` - where a single line has the following format: input [tab] output_list, e.g.:
    ```
    who is the owner of reading football club ['Xiu Li Dai', 'Dai Yongge', 'Dai Xiuli', 'Yongge Dai']
    ```
    - `ans` - where a single line of the gold file contains the expected output string, e.g.:
    ```
    Xiu Li Dai
    ```

- for retrieval evaluation, we expect a tab-separated list of Wikipedia page titles constituting positive contexts for a given query, e.g. given a question `who sings does he love me with reba`, a line with ground truth retrieval data could look as follows:
    ```
    Does He Love You Does He Love You Red Sandy Spika dress of Reba McEntire Greatest Hits Volume Two (Reba McEntire album) Shoot for the Moon (album)
    ```
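For reference, this is roughly how `eval_rag.py` (included further down in this change) reads the `qa` gold format, using `pandas` and `ast.literal_eval`; the file name is only an example:

```
import ast

import pandas as pd

# each line: question [tab] string representation of a list of acceptable answers
gold = pd.read_csv("biencoder-nq-dev.gold", sep="\t", header=None)
for question, answer_list in zip(gold[0], gold[1]):
    answers = ast.literal_eval(answer_list)  # e.g. ['Xiu Li Dai', 'Dai Yongge', ...]
    print(question, answers)
```
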
## Retrieval evaluation

We demonstrate how to evaluate retrieval against DPR evaluation data. You can download respective files from links listed [here](https://github.com/facebookresearch/DPR/blob/master/data/download_data.py#L39-L45).

1. Download and unzip the gold data file. We use the `biencoder-nq-dev` from https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-dev.json.gz.
2. Parse the unzipped file using `parse_dpr_relevance_data.py`:
    ```
    python examples/rag/parse_dpr_relevance_data.py --src_path path/to/unzipped/biencoder-nq-dev.json --evaluation_set path/to/output/biencoder-nq-dev.questions --gold_data_path path/to/output/biencoder-nq-dev.pages
    ```
3. Run evaluation:
    ```
    python examples/rag/eval_rag.py \
        --model_name_or_path $MODEL_NAME_OR_PATH \  # model name or path of the model we're evaluating
        --model_type rag_sequence \  # RAG model type (rag_token or rag_sequence)
        --evaluation_set path/to/output/biencoder-nq-dev.questions \  # an input dataset for evaluation
        --gold_data_path path/to/output/biencoder-nq-dev.pages \  # a dataset containing ground truth answers for samples from the evaluation_set
        --predictions_path path/to/retrieval_preds.tsv \  # name of the file in which predictions will be stored
        --eval_mode retrieval \  # indicates whether we're performing retrieval evaluation or e2e evaluation
        --recalculate  # if predictions_path already exists and this option is set, we regenerate the answers, otherwise we reuse the prediction file to calculate metrics
    ```

## End-to-end evaluation
```
python examples/rag/eval_rag.py \
    --model_name_or_path $MODEL_NAME_OR_PATH \
    --model_type rag_sequence \
    --evaluation_set path/to/test.source \
    --gold_data_path path/to/gold_data \
    --predictions_path path/to/e2e_preds.txt \
    --eval_mode e2e \  # indicates whether we're performing retrieval evaluation or e2e evaluation (default)
    --n_docs 5 \  # you can experiment with retrieving a different number of documents at evaluation time
    --print_predictions
```
@@ -0,0 +1,30 @@
import logging
import os

from pytorch_lightning.callbacks import ModelCheckpoint


logger = logging.getLogger(__name__)


def get_checkpoint_callback(output_dir, metric):
    """Saves the best model by validation EM score."""
    if metric == "rouge2":
        exp = "{val_avg_rouge2:.4f}-{step_count}"
    elif metric == "bleu":
        exp = "{val_avg_bleu:.4f}-{step_count}"
    elif metric == "em":
        exp = "{val_avg_em:.4f}-{step_count}"
    else:
        raise NotImplementedError(
            f"seq2seq callbacks only support rouge2, bleu and em, got {metric}. You can make your own by adding to this function."
        )

    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(output_dir, exp),
        monitor=f"val_{metric}",
        mode="max",
        save_top_k=3,
        period=0,  # maybe save a checkpoint every time val is run, not just end of epoch.
    )
    return checkpoint_callback
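The finetuning script below passes this callback to the PyTorch Lightning trainer; a rough usage sketch (the output directory and trainer arguments are illustrative):

```
import pytorch_lightning as pl

from examples.rag.callbacks import get_checkpoint_callback

# keep the checkpoints with the best validation exact match under output_dir
checkpoint_callback = get_checkpoint_callback("output_dir", "em")
trainer = pl.Trainer(gpus=1, max_epochs=1, checkpoint_callback=checkpoint_callback)
```
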
@ -0,0 +1,135 @@
|
|||
import logging
|
||||
import os
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import psutil
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from transformers import RagRetriever
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RagPyTorchDistributedRetriever(RagRetriever):
|
||||
"""
|
||||
A distributed retriever built on top of the ``torch.distributed`` communication package. During training all workers
|
||||
initialize their own instance of the retriever; however, only the main worker loads the index into memory. The index is stored
|
||||
in cpu memory. The index will also work well in a non-distributed setup.
|
||||
|
||||
Args:
|
||||
config (:class:`~transformers.RagConfig`):
|
||||
The configuration of the RAG model this Retriever is used with. Contains parameters indicating which ``Index`` to build.
|
||||
question_encoder_tokenizer (:class:`~transformers.PretrainedTokenizer`):
|
||||
The tokenizer that was used to tokenize the question.
|
||||
It is used to decode the question and then use the generator_tokenizer.
|
||||
generator_tokenizer (:class:`~transformers.PretrainedTokenizer`):
|
||||
The tokenizer used for the generator part of the RagModel.
|
||||
"""
|
||||
|
||||
_init_retrieval = False
|
||||
|
||||
def __init__(self, config, question_encoder_tokenizer, generator_tokenizer):
|
||||
super().__init__(
|
||||
config, question_encoder_tokenizer=question_encoder_tokenizer, generator_tokenizer=generator_tokenizer
|
||||
)
|
||||
|
||||
self.process_group = None
|
||||
|
||||
def init_retrieval(self, distributed_port: int):
|
||||
"""
|
||||
Retriever initialization function, needs to be called from the training process. The function sets some common parameters
|
||||
and environment variables. On top of that, (only) the main process in the process group loads the index into memory.
|
||||
|
||||
Args:
|
||||
distributed_port (:obj:`int`):
|
||||
The port on which the main communication of the training run is carried out. We set the port for retrieval-related
|
||||
communication as ``distributed_port + 1``.
|
||||
"""
|
||||
|
||||
logger.info("initializing retrieval")
|
||||
|
||||
# initializing a separate process group for retrieval as the default
|
||||
# nccl backend doesn't support gather/scatter operations while gloo
|
||||
# is too slow to replace nccl for the core gpu communication
|
||||
if dist.is_initialized():
|
||||
logger.info("dist initialized")
|
||||
# needs to be set manually
|
||||
os.environ["GLOO_SOCKET_IFNAME"] = self._infer_socket_ifname()
|
||||
# avoid clash with the NCCL port
|
||||
os.environ["MASTER_PORT"] = str(distributed_port + 1)
|
||||
self.process_group = dist.new_group(ranks=None, backend="gloo")
|
||||
|
||||
# initialize retriever only on the main worker
|
||||
if not dist.is_initialized() or self._is_main():
|
||||
logger.info("dist not initialized / main")
|
||||
self.index.init_index()
|
||||
|
||||
# all processes wait until the retriever is initialized by the main process
|
||||
if dist.is_initialized():
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
||||
def _is_main(self):
|
||||
return dist.get_rank(group=self.process_group) == 0
|
||||
|
||||
def _scattered(self, scatter_list, target_shape, target_type=torch.float32):
|
||||
target_tensor = torch.empty(target_shape, dtype=target_type)
|
||||
dist.scatter(target_tensor, src=0, scatter_list=scatter_list, group=self.process_group)
|
||||
return target_tensor
|
||||
|
||||
def _infer_socket_ifname(self):
|
||||
addrs = psutil.net_if_addrs()
|
||||
# a hacky way to deal with varying network interface names
|
||||
ifname = next((addr for addr in addrs if addr.startswith("e")), None)
|
||||
return ifname
|
||||
|
||||
def retrieve(self, question_hidden_states: np.ndarray, n_docs: int) -> Tuple[np.ndarray, List[dict]]:
|
||||
"""
|
||||
Retrieves documents for specified ``question_hidden_states``. The main process, which has the access to the index stored in memory, gathers queries
|
||||
from all the processes in the main training process group, performs the retrieval and scatters back the results.
|
||||
|
||||
Args:
|
||||
question_hidden_states (:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size)`):
|
||||
A batch of query vectors to retrieve with.
|
||||
n_docs (:obj:`int`):
|
||||
The number of docs retrieved per query.
|
||||
|
||||
Output:
|
||||
retrieved_doc_embeds (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs, dim)`
|
||||
The retrieval embeddings of the retrieved docs per query.
|
||||
doc_ids (:obj:`np.ndarray` of shape :obj:`batch_size, n_docs`)
|
||||
The ids of the documents in the index
|
||||
doc_dicts (:obj:`List[dict]`):
|
||||
The retrieved_doc_embeds examples per query.
|
||||
"""
|
||||
|
||||
# single GPU training
|
||||
if not dist.is_initialized():
|
||||
doc_ids, retrieved_doc_embeds = self._main_retrieve(question_hidden_states, n_docs)
|
||||
return retrieved_doc_embeds, doc_ids, self.index.get_doc_dicts(doc_ids)
|
||||
|
||||
# distributed training
|
||||
world_size = dist.get_world_size(group=self.process_group)
|
||||
|
||||
# gather logic
|
||||
gather_list = None
|
||||
if self._is_main():
|
||||
gather_list = [torch.empty(question_hidden_states.shape, dtype=torch.float32) for _ in range(world_size)]
|
||||
dist.gather(torch.tensor(question_hidden_states), dst=0, gather_list=gather_list, group=self.process_group)
|
||||
|
||||
# scatter logic
|
||||
n_queries = question_hidden_states.shape[0]
|
||||
scatter_ids = []
|
||||
scatter_vectors = []
|
||||
if self._is_main():
|
||||
assert len(gather_list) == world_size
|
||||
ids, vectors = self._main_retrieve(torch.cat(gather_list).numpy(), n_docs)
|
||||
ids, vectors = torch.tensor(ids), torch.tensor(vectors)
|
||||
scatter_ids = self._chunk_tensor(ids, n_queries)
|
||||
scatter_vectors = self._chunk_tensor(vectors, n_queries)
|
||||
doc_ids = self._scattered(scatter_ids, [n_queries, n_docs], target_type=torch.int64)
|
||||
retrieved_doc_embeds = self._scattered(scatter_vectors, [n_queries, n_docs, question_hidden_states.shape[1]])
|
||||
|
||||
return retrieved_doc_embeds.numpy(), doc_ids.numpy(), self.index.get_doc_dicts(doc_ids)
|
|
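How the distributed retriever is wired into training, condensed from `finetune.py` further down (the checkpoint name and port are illustrative):

```
from transformers import RagSequenceForGeneration

from examples.rag.distributed_retriever import RagPyTorchDistributedRetriever

# every worker constructs the retriever, but only the main worker will load the index
retriever = RagPyTorchDistributedRetriever.from_pretrained("facebook/rag-sequence-nq")
model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)

# once the DDP process group exists, load the index on the main worker and create
# the gloo group used to gather queries and scatter retrieval results
model.retriever.init_retrieval(12345)  # retrieval traffic uses distributed_port + 1
```
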
@ -0,0 +1,310 @@
|
|||
""" Evaluation script for RAG models."""
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import BartForConditionalGeneration, RagRetriever, RagSequenceForGeneration, RagTokenForGeneration
|
||||
from transformers import logging as transformers_logging
|
||||
|
||||
|
||||
sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # isort:skip
|
||||
from examples.rag.utils import exact_match_score, f1_score # noqa: E402 # isort:skip
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
transformers_logging.set_verbosity_info()
|
||||
|
||||
|
||||
def infer_model_type(model_name_or_path):
|
||||
if "token" in model_name_or_path:
|
||||
return "rag_token"
|
||||
if "sequence" in model_name_or_path:
|
||||
return "rag_sequence"
|
||||
if "bart" in model_name_or_path:
|
||||
return "bart"
|
||||
return None
|
||||
|
||||
|
||||
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
|
||||
return max(metric_fn(prediction, gt) for gt in ground_truths)
|
||||
|
||||
|
||||
def get_scores(args, preds_path, gold_data_path):
|
||||
hypos = [line.strip() for line in open(preds_path, "r").readlines()]
|
||||
answers = []
|
||||
|
||||
if args.gold_data_mode == "qa":
|
||||
data = pd.read_csv(gold_data_path, sep="\t", header=None)
|
||||
for answer_list in data[1]:
|
||||
ground_truths = ast.literal_eval(answer_list)
|
||||
answers.append(ground_truths)
|
||||
else:
|
||||
references = [line.strip() for line in open(gold_data_path, "r").readlines()]
|
||||
answers = [[reference] for reference in references]
|
||||
|
||||
f1 = em = total = 0
|
||||
for prediction, ground_truths in zip(hypos, answers):
|
||||
total += 1
|
||||
em += metric_max_over_ground_truths(exact_match_score, prediction, ground_truths)
|
||||
f1 += metric_max_over_ground_truths(f1_score, prediction, ground_truths)
|
||||
|
||||
em = 100.0 * em / total
|
||||
f1 = 100.0 * f1 / total
|
||||
|
||||
logger.info(f"F1: {f1:.2f}")
|
||||
logger.info(f"EM: {em:.2f}")
|
||||
|
||||
|
||||
def get_precision_at_k(args, preds_path, gold_data_path):
|
||||
k = args.k
|
||||
hypos = [line.strip() for line in open(preds_path, "r").readlines()]
|
||||
references = [line.strip() for line in open(gold_data_path, "r").readlines()]
|
||||
|
||||
em = total = 0
|
||||
for hypo, reference in zip(hypos, references):
|
||||
hypo_provenance = set(hypo.split("\t")[:k])
|
||||
ref_provenance = set(reference.split("\t")[1 : (k + 1)])
|
||||
total += 1
|
||||
em += len(hypo_provenance & ref_provenance) / k
|
||||
|
||||
em = 100.0 * em / total
|
||||
logger.info(f"Precision@{k}: {em: .2f}")
|
||||
|
||||
|
||||
def evaluate_batch_retrieval(args, rag_model, questions):
|
||||
def strip_title(title):
|
||||
if title.startswith('"'):
|
||||
title = title[1:]
|
||||
if title.endswith('"'):
|
||||
title = title[:-1]
|
||||
return title
|
||||
|
||||
retriever_input_ids = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus(
|
||||
questions,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
)["input_ids"].to(args.device)
|
||||
|
||||
question_enc_outputs = rag_model.rag.question_encoder(retriever_input_ids, return_dict=True)
|
||||
question_enc_pool_output = question_enc_outputs.pooler_output
|
||||
|
||||
result = rag_model.retriever(
|
||||
retriever_input_ids,
|
||||
question_enc_pool_output.cpu().detach().to(torch.float32).numpy(),
|
||||
prefix=rag_model.rag.generator.config.prefix,
|
||||
n_docs=rag_model.config.n_docs,
|
||||
return_tensors="pt",
|
||||
)
|
||||
all_docs = rag_model.retriever.index.get_doc_dicts(result.doc_ids)
|
||||
provenance_strings = []
|
||||
for docs in all_docs:
|
||||
provenance = [strip_title(title) for title in docs["title"]]
|
||||
provenance_strings.append("\t".join(provenance))
|
||||
return provenance_strings
|
||||
|
||||
|
||||
def evaluate_batch_e2e(args, rag_model, questions):
|
||||
with torch.no_grad():
|
||||
input_ids = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus(
|
||||
questions, return_tensors="pt", padding=True, truncation=True
|
||||
)["input_ids"].to(args.device)
|
||||
outputs = rag_model.generate( # rag_model overwrites generate
|
||||
input_ids,
|
||||
num_beams=args.num_beams,
|
||||
min_length=args.min_length,
|
||||
max_length=args.max_length,
|
||||
early_stopping=False,
|
||||
num_return_sequences=1,
|
||||
bad_words_ids=[[0, 0]],  # BART likes to repeat BOS tokens, don't allow it to generate more than one
|
||||
clean_up_tokenization=True,
|
||||
print_docs=args.print_docs,
|
||||
)
|
||||
answers = rag_model.retriever.generator_tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
|
||||
if args.print_predictions:
|
||||
for q, a in zip(questions, answers):
|
||||
logger.info("Q: {} - A: {}".format(q, a))
|
||||
|
||||
return answers
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
choices=["rag_sequence", "rag_token", "bart"],
|
||||
type=str,
|
||||
help="RAG model type: rag_sequence, rag_token or bart, if none specified, the type is inferred from the model_name_or_path",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index_name",
|
||||
default=None,
|
||||
choices=["hf", "legacy"],
|
||||
type=str,
|
||||
help="RAG model retriever type",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index_path",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Path to the retrieval index",
|
||||
)
|
||||
parser.add_argument("--n_docs", default=5, type=int, help="Number of retrieved docs")
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pretrained checkpoints or model identifier from huggingface.co/models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval_mode",
|
||||
choices=["e2e", "retrieval"],
|
||||
default="e2e",
|
||||
type=str,
|
||||
help="Evaluation mode, e2e calculates exact match and F1 of the downstream task, retrieval calculates precision@k.",
|
||||
)
|
||||
parser.add_argument("--k", default=1, type=int, help="k for the precision@k calculation")
|
||||
parser.add_argument(
|
||||
"--evaluation_set",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to a file containing evaluation samples",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gold_data_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to a tab-separated file with gold samples",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gold_data_mode",
|
||||
default="qa",
|
||||
type=str,
|
||||
choices=["qa", "ans"],
|
||||
help="Format of the gold data file"
|
||||
"qa - a single line in the following format: question [tab] answer_list"
|
||||
"ans - a single line of the gold file contains the expected answer string",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--predictions_path",
|
||||
type=str,
|
||||
default="predictions.txt",
|
||||
help="Name of the predictions file, to be stored in the checkpoints directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval_all_checkpoints",
|
||||
action="store_true",
|
||||
help="Evaluate all checkpoints starting with the same prefix as model_name and ending with the step number",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval_batch_size",
|
||||
default=8,
|
||||
type=int,
|
||||
help="Batch size per GPU/CPU for evaluation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--recalculate",
|
||||
help="Recalculate predictions even if the prediction file exists",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_beams",
|
||||
default=4,
|
||||
type=int,
|
||||
help="Number of beams to be used when generating answers",
|
||||
)
|
||||
parser.add_argument("--min_length", default=1, type=int, help="Min length of the generated answers")
|
||||
parser.add_argument("--max_length", default=50, type=int, help="Max length of the generated answers")
|
||||
|
||||
parser.add_argument(
|
||||
"--print_predictions",
|
||||
action="store_true",
|
||||
help="If True, prints predictions while evaluating.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--print_docs",
|
||||
action="store_true",
|
||||
help="If True, prints docs retrieved while generating.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
return args
|
||||
|
||||
|
||||
def main(args):
|
||||
model_kwargs = {}
|
||||
if args.model_type is None:
|
||||
args.model_type = infer_model_type(args.model_name_or_path)
|
||||
assert args.model_type is not None
|
||||
if args.model_type.startswith("rag"):
|
||||
model_class = RagTokenForGeneration if args.model_type == "rag_token" else RagSequenceForGeneration
|
||||
model_kwargs["n_docs"] = args.n_docs
|
||||
if args.index_name is not None:
|
||||
model_kwargs["index_name"] = args.index_name
|
||||
if args.index_path is not None:
|
||||
model_kwargs["index_path"] = args.index_path
|
||||
else:
|
||||
model_class = BartForConditionalGeneration
|
||||
|
||||
checkpoints = (
|
||||
[f.path for f in os.scandir(args.model_name_or_path) if f.is_dir()]
|
||||
if args.eval_all_checkpoints
|
||||
else [args.model_name_or_path]
|
||||
)
|
||||
|
||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||
|
||||
score_fn = get_scores if args.eval_mode == "e2e" else get_precision_at_k
|
||||
evaluate_batch_fn = evaluate_batch_e2e if args.eval_mode == "e2e" else evaluate_batch_retrieval
|
||||
|
||||
for checkpoint in checkpoints:
|
||||
if os.path.exists(args.predictions_path) and (not args.recalculate):
|
||||
logger.info("Calculating metrics based on an existing predictions file: {}".format(args.predictions_path))
|
||||
score_fn(args, args.predictions_path, args.gold_data_path)
|
||||
continue
|
||||
|
||||
logger.info("***** Running evaluation for {} *****".format(checkpoint))
|
||||
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||
logger.info(" Predictions will be stored under {}".format(args.predictions_path))
|
||||
|
||||
if args.model_type.startswith("rag"):
|
||||
retriever = RagRetriever.from_pretrained(checkpoint, **model_kwargs)
|
||||
model = model_class.from_pretrained(checkpoint, retriever=retriever, **model_kwargs)
|
||||
model.retriever.init_retrieval()
|
||||
else:
|
||||
model = model_class.from_pretrained(checkpoint, **model_kwargs)
|
||||
model.to(args.device)
|
||||
|
||||
with open(args.evaluation_set, "r") as eval_file, open(args.predictions_path, "w") as preds_file:
|
||||
questions = []
|
||||
for line in tqdm(eval_file):
|
||||
questions.append(line.strip())
|
||||
if len(questions) == args.eval_batch_size:
|
||||
answers = evaluate_batch_fn(args, model, questions)
|
||||
preds_file.write("\n".join(answers) + "\n")
|
||||
preds_file.flush()
|
||||
questions = []
|
||||
if len(questions) > 0:
|
||||
answers = evaluate_batch_fn(args, model, questions)
|
||||
preds_file.write("\n".join(answers))
|
||||
preds_file.flush()
|
||||
|
||||
score_fn(args, args.predictions_path, args.gold_data_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
main(args)
|
|
@ -0,0 +1,474 @@
|
|||
"""Finetuning script for RAG models. Adapted from examples.seq2seq.finetune.py"""
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
BartForConditionalGeneration,
|
||||
RagConfig,
|
||||
RagSequenceForGeneration,
|
||||
RagTokenForGeneration,
|
||||
RagTokenizer,
|
||||
T5ForConditionalGeneration,
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
from transformers import logging as transformers_logging
|
||||
|
||||
|
||||
sys.path.append(os.path.join(os.getcwd()))  # noqa: E402 # isort:skip
|
||||
|
||||
from examples.lightning_base import BaseTransformer, add_generic_args, generic_train # noqa: E402 # isort:skip
|
||||
from examples.rag.callbacks import get_checkpoint_callback # noqa: E402 # isort:skip
|
||||
from examples.rag.distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
|
||||
from examples.rag.utils import ( # noqa: E402 # isort:skip
|
||||
Seq2SeqDataset,
|
||||
calculate_exact_match,
|
||||
is_rag_model,
|
||||
set_extra_model_params,
|
||||
)
|
||||
from examples.seq2seq.callbacks import Seq2SeqLoggingCallback, get_early_stopping_callback # noqa: E402 # isort:skip
|
||||
from examples.seq2seq.utils import ( # noqa: E402 # isort:skip
|
||||
flatten_list,
|
||||
get_git_info,
|
||||
lmap,
|
||||
pickle_save,
|
||||
save_git_info,
|
||||
save_json,
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
transformers_logging.set_verbosity_info()
|
||||
|
||||
|
||||
class AttrDict(dict):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(AttrDict, self).__init__(*args, **kwargs)
|
||||
self.__dict__ = self
|
||||
|
||||
|
||||
class GenerativeQAModule(BaseTransformer):
|
||||
mode = "generative_qa"
|
||||
loss_names = ["loss"]
|
||||
metric_names = ["em"]
|
||||
val_metric = "em"
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
# when loading from a pytorch lightning checkpoint, hparams are passed as dict
|
||||
if isinstance(hparams, dict):
|
||||
hparams = AttrDict(hparams)
|
||||
if hparams.model_type == "rag_sequence":
|
||||
self.model_class = RagSequenceForGeneration
|
||||
elif hparams.model_type == "rag_token":
|
||||
self.model_class = RagTokenForGeneration
|
||||
elif hparams.model_type == "bart":
|
||||
self.model_class = BartForConditionalGeneration
|
||||
else:
|
||||
self.model_class = T5ForConditionalGeneration
|
||||
self.is_rag_model = is_rag_model(hparams.model_type)
|
||||
|
||||
config_class = RagConfig if self.is_rag_model else AutoConfig
|
||||
config = config_class.from_pretrained(hparams.model_name_or_path)
|
||||
|
||||
# set extra_model_params for generator configs and load_model
|
||||
extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "attention_dropout", "dropout")
|
||||
if self.is_rag_model:
|
||||
if hparams.prefix is not None:
|
||||
config.generator.prefix = args.prefix
|
||||
config.label_smoothing = hparams.label_smoothing
|
||||
hparams, config.generator = set_extra_model_params(extra_model_params, hparams, config.generator)
|
||||
retriever = RagPyTorchDistributedRetriever.from_pretrained(hparams.model_name_or_path)
|
||||
model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config, retriever=retriever)
|
||||
prefix = config.question_encoder.prefix
|
||||
else:
|
||||
if hparams.prefix is not None:
|
||||
config.prefix = args.prefix
|
||||
hparams, config = set_extra_model_params(extra_model_params, hparams, config)
|
||||
model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config)
|
||||
prefix = config.prefix
|
||||
|
||||
tokenizer = (
|
||||
RagTokenizer.from_pretrained(hparams.model_name_or_path)
|
||||
if self.is_rag_model
|
||||
else AutoTokenizer.from_pretrained(hparams.model_name_or_path)
|
||||
)
|
||||
|
||||
super().__init__(hparams, config=config, tokenizer=tokenizer, model=model)
|
||||
|
||||
save_git_info(self.hparams.output_dir)
|
||||
self.output_dir = Path(self.hparams.output_dir)
|
||||
self.metrics_save_path = Path(self.output_dir) / "metrics.json"
|
||||
self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
|
||||
pickle_save(self.hparams, self.hparams_save_path)
|
||||
self.step_count = 0
|
||||
self.metrics = defaultdict(list)
|
||||
|
||||
self.dataset_kwargs: dict = dict(
|
||||
data_dir=self.hparams.data_dir,
|
||||
max_source_length=self.hparams.max_source_length,
|
||||
prefix=prefix or "",
|
||||
)
|
||||
n_observations_per_split = {
|
||||
"train": self.hparams.n_train,
|
||||
"val": self.hparams.n_val,
|
||||
"test": self.hparams.n_test,
|
||||
}
|
||||
self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}
|
||||
|
||||
self.target_lens = {
|
||||
"train": self.hparams.max_target_length,
|
||||
"val": self.hparams.val_max_target_length,
|
||||
"test": self.hparams.test_max_target_length,
|
||||
}
|
||||
assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
|
||||
assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
|
||||
|
||||
self.hparams.git_sha = get_git_info()["repo_sha"]
|
||||
self.num_workers = hparams.num_workers
|
||||
self.distributed_port = self.hparams.distributed_port
|
||||
|
||||
def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True):
|
||||
logger.info("Custom init_ddp_connection.")
|
||||
os.environ["MASTER_PORT"] = str(self.distributed_port)
|
||||
super().init_ddp_connection(global_rank, world_size, is_slurm_managing_tasks)
|
||||
if self.is_rag_model:
|
||||
self.model.retriever.init_retrieval(self.distributed_port)
|
||||
|
||||
def forward(self, input_ids, **kwargs):
|
||||
return self.model(input_ids, **kwargs)
|
||||
|
||||
def ids_to_clean_text(self, generated_ids: List[int]):
|
||||
gen_text = self.tokenizer.batch_decode(
|
||||
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
||||
)
|
||||
return lmap(str.strip, gen_text)
|
||||
|
||||
def _step(self, batch: dict) -> Tuple:
|
||||
source_ids, source_mask, target_ids = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
|
||||
|
||||
rag_kwargs = {}
|
||||
if isinstance(self.model, T5ForConditionalGeneration):
|
||||
decoder_input_ids = self.model._shift_right(target_ids)
|
||||
lm_labels = target_ids
|
||||
elif isinstance(self.model, BartForConditionalGeneration):
|
||||
decoder_input_ids = target_ids[:, :-1].contiguous()
|
||||
lm_labels = target_ids[:, 1:].clone()
|
||||
else:
|
||||
assert self.is_rag_model
|
||||
generator = self.model.rag.generator
|
||||
if isinstance(generator, T5ForConditionalGeneration):
|
||||
decoder_start_token_id = generator.config.decoder_start_token_id
|
||||
decoder_input_ids = (
|
||||
torch.cat(
|
||||
[torch.Tensor([[decoder_start_token_id]] * target_ids.shape[0]).to(target_ids), target_ids],
|
||||
dim=1,
|
||||
)
|
||||
if target_ids.shape[0] < self.target_lens["train"]
|
||||
else generator._shift_right(target_ids)
|
||||
)
|
||||
elif isinstance(generator, BartForConditionalGeneration):
|
||||
decoder_input_ids = target_ids
|
||||
lm_labels = decoder_input_ids
|
||||
rag_kwargs["reduce_loss"] = True
|
||||
|
||||
assert decoder_input_ids is not None
|
||||
|
||||
outputs = self(
|
||||
source_ids,
|
||||
attention_mask=source_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
use_cache=False,
|
||||
labels=lm_labels,
|
||||
return_dict=True,
|
||||
**rag_kwargs,
|
||||
)
|
||||
|
||||
loss = outputs["loss"]
|
||||
return (loss,)
|
||||
|
||||
@property
|
||||
def pad(self) -> int:
|
||||
raise NotImplementedError("pad not implemented")
|
||||
|
||||
def training_step(self, batch, batch_idx) -> Dict:
|
||||
loss_tensors = self._step(batch)
|
||||
|
||||
logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
|
||||
# tokens per batch
|
||||
tgt_pad_token_id = (
|
||||
self.tokenizer.generator.pad_token_id
|
||||
if isinstance(self.tokenizer, RagTokenizer)
|
||||
else self.tokenizer.pad_token_id
|
||||
)
|
||||
src_pad_token_id = (
|
||||
self.tokenizer.question_encoder.pad_token_id
|
||||
if isinstance(self.tokenizer, RagTokenizer)
|
||||
else self.tokenizer.pad_token_id
|
||||
)
|
||||
logs["tpb"] = (
|
||||
batch["input_ids"].ne(src_pad_token_id).sum() + batch["decoder_input_ids"].ne(tgt_pad_token_id).sum()
|
||||
)
|
||||
|
||||
return {"loss": loss_tensors[0], "log": logs}
|
||||
|
||||
def validation_step(self, batch, batch_idx) -> Dict:
|
||||
return self._generative_step(batch)
|
||||
|
||||
def validation_epoch_end(self, outputs, prefix="val") -> Dict:
|
||||
self.step_count += 1
|
||||
losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
|
||||
loss = losses["loss"]
|
||||
gen_metrics = {
|
||||
k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]
|
||||
}
|
||||
metrics_tensor: torch.FloatTensor = torch.tensor(gen_metrics[self.val_metric]).type_as(loss)
|
||||
gen_metrics.update({k: v.item() for k, v in losses.items()})
|
||||
|
||||
# fix for https://github.com/PyTorchLightning/pytorch-lightning/issues/2424
|
||||
if dist.is_initialized():
|
||||
dist.all_reduce(metrics_tensor, op=dist.ReduceOp.SUM)
|
||||
metrics_tensor = metrics_tensor / dist.get_world_size()
|
||||
gen_metrics.update({self.val_metric: metrics_tensor.item()})
|
||||
|
||||
losses.update(gen_metrics)
|
||||
metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
|
||||
metrics["step_count"] = self.step_count
|
||||
self.save_metrics(metrics, prefix) # writes to self.metrics_save_path
|
||||
preds = flatten_list([x["preds"] for x in outputs])
|
||||
return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_{self.val_metric}": metrics_tensor}
|
||||
|
||||
def save_metrics(self, latest_metrics, type_path) -> None:
|
||||
self.metrics[type_path].append(latest_metrics)
|
||||
save_json(self.metrics, self.metrics_save_path)
|
||||
|
||||
def calc_generative_metrics(self, preds, target) -> Dict:
|
||||
return calculate_exact_match(preds, target)
|
||||
|
||||
def _generative_step(self, batch: dict) -> dict:
|
||||
start_time = time.time()
|
||||
generated_ids = self.model.generate(
|
||||
batch["input_ids"],
|
||||
do_deduplication=False, # rag specific parameter
|
||||
use_cache=True,
|
||||
min_length=1,
|
||||
max_length=self.target_lens["val"],
|
||||
)
|
||||
|
||||
gen_time = (time.time() - start_time) / batch["input_ids"].shape[0]
|
||||
preds: List[str] = self.ids_to_clean_text(generated_ids)
|
||||
target: List[str] = self.ids_to_clean_text(batch["decoder_input_ids"])
|
||||
loss_tensors = self._step(batch)
|
||||
base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
|
||||
gen_metrics: Dict = self.calc_generative_metrics(preds, target)
|
||||
|
||||
summ_len = np.mean(lmap(len, generated_ids))
|
||||
base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **gen_metrics)
|
||||
return base_metrics
|
||||
|
||||
def test_step(self, batch, batch_idx):
|
||||
return self._generative_step(batch)
|
||||
|
||||
def test_epoch_end(self, outputs):
|
||||
return self.validation_epoch_end(outputs, prefix="test")
|
||||
|
||||
def get_dataset(self, type_path) -> Seq2SeqDataset:
|
||||
n_obs = self.n_obs[type_path]
|
||||
max_target_length = self.target_lens[type_path]
|
||||
dataset = Seq2SeqDataset(
|
||||
self.tokenizer,
|
||||
type_path=type_path,
|
||||
n_obs=n_obs,
|
||||
max_target_length=max_target_length,
|
||||
**self.dataset_kwargs,
|
||||
)
|
||||
return dataset
|
||||
|
||||
def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
|
||||
dataset = self.get_dataset(type_path)
|
||||
sampler = None
|
||||
if self.hparams.sortish_sampler and type_path == "train":
|
||||
assert self.hparams.gpus <= 1 # TODO: assert earlier
|
||||
sampler = dataset.make_sortish_sampler(batch_size)
|
||||
shuffle = False
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
collate_fn=dataset.collate_fn,
|
||||
shuffle=shuffle,
|
||||
num_workers=self.num_workers,
|
||||
sampler=sampler,
|
||||
)
|
||||
return dataloader
|
||||
|
||||
def train_dataloader(self) -> DataLoader:
|
||||
dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
|
||||
t_total = (
|
||||
(len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))
|
||||
// self.hparams.accumulate_grad_batches
|
||||
* float(self.hparams.max_epochs)
|
||||
)
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
|
||||
)
|
||||
if max(scheduler.get_last_lr()) == 0:
|
||||
warnings.warn("All learning rates are 0")
|
||||
self.lr_scheduler = scheduler
|
||||
return dataloader
|
||||
|
||||
def val_dataloader(self) -> DataLoader:
|
||||
return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)
|
||||
|
||||
def test_dataloader(self) -> DataLoader:
|
||||
return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size)
|
||||
|
||||
@pl.utilities.rank_zero_only
|
||||
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
|
||||
save_path = self.output_dir.joinpath("checkpoint{}".format(self.step_count))
|
||||
self.model.config.save_step = self.step_count
|
||||
self.model.save_pretrained(save_path)
|
||||
self.tokenizer.save_pretrained(save_path)
|
||||
|
||||
@staticmethod
|
||||
def add_model_specific_args(parser, root_dir):
|
||||
BaseTransformer.add_model_specific_args(parser, root_dir)
|
||||
add_generic_args(parser, root_dir)
|
||||
parser.add_argument(
|
||||
"--max_source_length",
|
||||
default=128,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_target_length",
|
||||
default=25,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--val_max_target_length",
|
||||
default=25,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test_max_target_length",
|
||||
default=25,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument("--sortish_sampler", action="store_true", default=False)
|
||||
parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
|
||||
parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
||||
parser.add_argument("--n_val", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
||||
parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
||||
parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
|
||||
parser.add_argument(
|
||||
"--prefix",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Prefix added at the beginning of each text, typically used with T5-based models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--early_stopping_patience",
|
||||
type=int,
|
||||
default=-1,
|
||||
required=False,
|
||||
help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will affect it.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--distributed-port", type=int, default=-1, required=False, help="Port number for distributed training."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
choices=["rag_sequence", "rag_token", "bart", "t5"],
|
||||
type=str,
|
||||
help="RAG model type: sequence or token, if none specified, the type is inferred from the model_name_or_path",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def main(args, model=None) -> GenerativeQAModule:
|
||||
Path(args.output_dir).mkdir(exist_ok=True)
|
||||
if model is None:
|
||||
model: GenerativeQAModule = GenerativeQAModule(args)
|
||||
|
||||
dataset = Path(args.data_dir).name
|
||||
if (
|
||||
args.logger_name == "default"
|
||||
or args.fast_dev_run
|
||||
or str(args.output_dir).startswith("/tmp")
|
||||
or str(args.output_dir).startswith("/var")
|
||||
):
|
||||
logger = True # don't pollute wandb logs unnecessarily
|
||||
elif args.logger_name == "wandb":
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
|
||||
project = os.environ.get("WANDB_PROJECT", dataset)
|
||||
logger = WandbLogger(name=model.output_dir.name, project=project)
|
||||
|
||||
elif args.logger_name == "wandb_shared":
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
|
||||
logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
|
||||
|
||||
es_callback = (
|
||||
get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
|
||||
if args.early_stopping_patience >= 0
|
||||
else False
|
||||
)
|
||||
trainer: pl.Trainer = generic_train(
|
||||
model,
|
||||
args,
|
||||
logging_callback=Seq2SeqLoggingCallback(),
|
||||
checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
|
||||
early_stopping_callback=es_callback,
|
||||
logger=logger,
|
||||
)
|
||||
pickle_save(model.hparams, model.output_dir / "hparams.pkl")
|
||||
|
||||
if not args.do_predict:
|
||||
return model
|
||||
|
||||
model.hparams.test_checkpoint = ""
|
||||
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
|
||||
if checkpoints:
|
||||
model.hparams.test_checkpoint = checkpoints[-1]
|
||||
trainer.resume_from_checkpoint = checkpoints[-1] # best checkpoint
|
||||
trainer.logger.log_hyperparams(model.hparams)
|
||||
|
||||
# test() without a model tests using the best checkpoint automatically
|
||||
trainer.test()
|
||||
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
|
@@ -0,0 +1,34 @@
# Add parent directory to python path to access lightning_base.py
export PYTHONPATH="../":"${PYTHONPATH}"

# A sample finetuning run, you need to specify data_dir, output_dir and model_name_or_path
# run ./examples/rag/finetune.sh --help to see all the possible options

python examples/rag/finetune.py \
    --data_dir $DATA_DIR \
    --output_dir $OUTPUT_DIR \
    --model_name_or_path $MODEL_NAME_OR_PATH \
    --model_type rag_sequence \
    --fp16 \
    --gpus 8 \
    --do_train \
    --do_predict \
    --n_val -1 \
    --val_check_interval 0.25 \
    --train_batch_size 8 \
    --eval_batch_size 1 \
    --max_source_length 128 \
    --max_target_length 25 \
    --val_max_target_length 25 \
    --test_max_target_length 25 \
    --label_smoothing 0.1 \
    --dropout 0.1 \
    --attention_dropout 0.1 \
    --weight_decay 0.001 \
    --adam_epsilon 1e-08 \
    --max_grad_norm 0.1 \
    --lr_scheduler polynomial \
    --learning_rate 3e-05 \
    --num_train_epochs 100 \
    --warmup_steps 500 \
    --gradient_accumulation_steps 1
@ -0,0 +1,47 @@
|
|||
"""
|
||||
This script reads DPR retriever training data and parses each datapoint. We save a line per datapoint.
|
||||
Each line consists of the query followed by a tab-separated list of Wikipedia page titles constituting
|
||||
positive contexts for a given query.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--src_path",
|
||||
type=str,
|
||||
default="biencoder-nq-dev.json",
|
||||
help="Path to raw DPR training data",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--evaluation_set",
|
||||
type=str,
|
||||
help="where to store parsed evaluation_set file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gold_data_path",
|
||||
type=str,
|
||||
help="where to store parsed gold_data_path file",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.src_path, "r") as src_file, open(args.evaluation_set, "w") as eval_file, open(
|
||||
args.gold_data_path, "w"
|
||||
) as gold_file:
|
||||
dpr_records = json.load(src_file)
|
||||
for dpr_record in tqdm(dpr_records):
|
||||
question = dpr_record["question"]
|
||||
contexts = [context["title"] for context in dpr_record["positive_ctxs"]]
|
||||
eval_file.write(question + "\n")
|
||||
gold_file.write("\t".join(contexts) + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
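# Example invocation (output file names are illustrative):
#   python examples/rag/parse_dpr_relevance_data.py \
#       --src_path biencoder-nq-dev.json \
#       --evaluation_set nq_dev.questions \
#       --gold_data_path nq_dev.gold_titles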
|
|
@ -0,0 +1,4 @@
|
|||
faiss-cpu >= 1.6.3
|
||||
datasets >= 1.0.1
|
||||
psutil >= 5.7.0
|
||||
torch >= 1.4.0
|
|
@ -0,0 +1,156 @@
|
|||
import json
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from unittest import TestCase
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
from datasets import Dataset
|
||||
|
||||
import faiss
|
||||
from transformers.configuration_bart import BartConfig
|
||||
from transformers.configuration_dpr import DPRConfig
|
||||
from transformers.configuration_rag import RagConfig
|
||||
from transformers.file_utils import is_datasets_available, is_faiss_available, is_psutil_available, is_torch_available
|
||||
from transformers.tokenization_bart import BartTokenizer
|
||||
from transformers.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
|
||||
from transformers.tokenization_dpr import DPRQuestionEncoderTokenizer
|
||||
from transformers.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
|
||||
|
||||
|
||||
sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # noqa: E402 # isort:skip
|
||||
|
||||
from examples.rag.distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
|
||||
|
||||
|
||||
def require_distributed_retrieval(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires a set of dependencies necessary to perform retrieval with
|
||||
:class:`~transformers.RagRetriever`.
|
||||
|
||||
These tests are skipped when respective libraries are not installed.
|
||||
|
||||
"""
|
||||
if not (is_torch_available() and is_datasets_available() and is_faiss_available() and is_psutil_available()):
|
||||
test_case = unittest.skip("test requires PyTorch, Datasets, Faiss, psutil")(test_case)
|
||||
return test_case
|
||||
|
||||
|
||||
@require_distributed_retrieval
|
||||
class RagRetrieverTest(TestCase):
|
||||
def setUp(self):
|
||||
self.tmpdirname = tempfile.mkdtemp()
|
||||
self.retrieval_vector_size = 8
|
||||
|
||||
# DPR tok
|
||||
vocab_tokens = [
|
||||
"[UNK]",
|
||||
"[CLS]",
|
||||
"[SEP]",
|
||||
"[PAD]",
|
||||
"[MASK]",
|
||||
"want",
|
||||
"##want",
|
||||
"##ed",
|
||||
"wa",
|
||||
"un",
|
||||
"runn",
|
||||
"##ing",
|
||||
",",
|
||||
"low",
|
||||
"lowest",
|
||||
]
|
||||
dpr_tokenizer_path = os.path.join(self.tmpdirname, "dpr_tokenizer")
|
||||
os.makedirs(dpr_tokenizer_path, exist_ok=True)
|
||||
self.vocab_file = os.path.join(dpr_tokenizer_path, DPR_VOCAB_FILES_NAMES["vocab_file"])
|
||||
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||
|
||||
# BART tok
|
||||
vocab = [
|
||||
"l",
|
||||
"o",
|
||||
"w",
|
||||
"e",
|
||||
"r",
|
||||
"s",
|
||||
"t",
|
||||
"i",
|
||||
"d",
|
||||
"n",
|
||||
"\u0120",
|
||||
"\u0120l",
|
||||
"\u0120n",
|
||||
"\u0120lo",
|
||||
"\u0120low",
|
||||
"er",
|
||||
"\u0120lowest",
|
||||
"\u0120newer",
|
||||
"\u0120wider",
|
||||
"<unk>",
|
||||
]
|
||||
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
||||
merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
|
||||
self.special_tokens_map = {"unk_token": "<unk>"}
|
||||
|
||||
bart_tokenizer_path = os.path.join(self.tmpdirname, "bart_tokenizer")
|
||||
os.makedirs(bart_tokenizer_path, exist_ok=True)
|
||||
self.vocab_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["vocab_file"])
|
||||
self.merges_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["merges_file"])
|
||||
with open(self.vocab_file, "w", encoding="utf-8") as fp:
|
||||
fp.write(json.dumps(vocab_tokens) + "\n")
|
||||
with open(self.merges_file, "w", encoding="utf-8") as fp:
|
||||
fp.write("\n".join(merges))
|
||||
|
||||
def get_dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer:
|
||||
return DPRQuestionEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer"))
|
||||
|
||||
def get_bart_tokenizer(self) -> BartTokenizer:
|
||||
return BartTokenizer.from_pretrained(os.path.join(self.tmpdirname, "bart_tokenizer"))
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmpdirname)
|
||||
|
||||
def get_dummy_pytorch_distributed_retriever(self, init_retrieval, port=12345) -> RagPyTorchDistributedRetriever:
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
"id": ["0", "1"],
|
||||
"text": ["foo", "bar"],
|
||||
"title": ["Foo", "Bar"],
|
||||
"embeddings": [np.ones(self.retrieval_vector_size), 2 * np.ones(self.retrieval_vector_size)],
|
||||
}
|
||||
)
|
||||
dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)
|
||||
config = RagConfig(
|
||||
retrieval_vector_size=self.retrieval_vector_size,
|
||||
question_encoder=DPRConfig().to_dict(),
|
||||
generator=BartConfig().to_dict(),
|
||||
)
|
||||
with patch("transformers.retrieval_rag.load_dataset") as mock_load_dataset:
|
||||
mock_load_dataset.return_value = dataset
|
||||
retriever = RagPyTorchDistributedRetriever(
|
||||
config,
|
||||
question_encoder_tokenizer=self.get_dpr_tokenizer(),
|
||||
generator_tokenizer=self.get_bart_tokenizer(),
|
||||
)
|
||||
if init_retrieval:
|
||||
retriever.init_retrieval(port)
|
||||
return retriever
|
||||
|
||||
def test_pytorch_distributed_retriever_retrieve(self):
|
||||
n_docs = 1
|
||||
retriever = self.get_dummy_pytorch_distributed_retriever(init_retrieval=True)
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
)
|
||||
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
|
||||
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
|
||||
self.assertEqual(len(doc_dicts), 2)
|
||||
self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
|
||||
self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
|
||||
self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc
|
||||
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
|
||||
self.assertListEqual(list(doc_ids), [1, 0])
|
|
@ -0,0 +1,187 @@
|
|||
import linecache
|
||||
import re
|
||||
import string
|
||||
from collections import Counter
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from examples.seq2seq.utils import SortishSampler, trim_batch
|
||||
from transformers import BartTokenizer, RagTokenizer, T5Tokenizer
|
||||
|
||||
|
||||
def encode_line(tokenizer, line, max_length, padding_side, pad_to_max_length=True, return_tensors="pt"):
|
||||
extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) and not line.startswith(" ") else {}
|
||||
tokenizer.padding_side = padding_side
|
||||
return tokenizer(
|
||||
[line],
|
||||
max_length=max_length,
|
||||
padding="max_length" if pad_to_max_length else None,
|
||||
truncation=True,
|
||||
return_tensors=return_tensors,
|
||||
add_special_tokens=True,
|
||||
**extra_kw,
|
||||
)
|
||||
|
||||
|
||||
class Seq2SeqDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer,
|
||||
data_dir,
|
||||
max_source_length,
|
||||
max_target_length,
|
||||
type_path="train",
|
||||
n_obs=None,
|
||||
src_lang=None,
|
||||
tgt_lang=None,
|
||||
prefix="",
|
||||
):
|
||||
super().__init__()
|
||||
self.src_file = Path(data_dir).joinpath(type_path + ".source")
|
||||
self.tgt_file = Path(data_dir).joinpath(type_path + ".target")
|
||||
self.src_lens = self.get_char_lens(self.src_file)
|
||||
self.max_source_length = max_source_length
|
||||
self.max_target_length = max_target_length
|
||||
assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
|
||||
self.tokenizer = tokenizer
|
||||
self.prefix = prefix
|
||||
if n_obs is not None:
|
||||
self.src_lens = self.src_lens[:n_obs]
|
||||
self.src_lang = src_lang
|
||||
self.tgt_lang = tgt_lang
|
||||
|
||||
def __len__(self):
|
||||
return len(self.src_lens)
|
||||
|
||||
def __getitem__(self, index) -> Dict[str, torch.Tensor]:
|
||||
index = index + 1 # linecache starts at 1
|
||||
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
|
||||
tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
|
||||
assert source_line, f"empty source line for index {index}"
|
||||
assert tgt_line, f"empty tgt line for index {index}"
|
||||
|
||||
# Need to add eos token manually for T5
|
||||
if isinstance(self.tokenizer, T5Tokenizer):
|
||||
source_line += self.tokenizer.eos_token
|
||||
tgt_line += self.tokenizer.eos_token
|
||||
|
||||
# Pad source and target to the right
|
||||
source_tokenizer = (
|
||||
self.tokenizer.question_encoder if isinstance(self.tokenizer, RagTokenizer) else self.tokenizer
|
||||
)
|
||||
target_tokenizer = self.tokenizer.generator if isinstance(self.tokenizer, RagTokenizer) else self.tokenizer
|
||||
|
||||
source_inputs = encode_line(source_tokenizer, source_line, self.max_source_length, "right")
|
||||
target_inputs = encode_line(target_tokenizer, tgt_line, self.max_target_length, "right")
|
||||
|
||||
source_ids = source_inputs["input_ids"].squeeze()
|
||||
target_ids = target_inputs["input_ids"].squeeze()
|
||||
src_mask = source_inputs["attention_mask"].squeeze()
|
||||
return {
|
||||
"input_ids": source_ids,
|
||||
"attention_mask": src_mask,
|
||||
"decoder_input_ids": target_ids,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_char_lens(data_file):
|
||||
return [len(x) for x in Path(data_file).open().readlines()]
|
||||
|
||||
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
|
||||
input_ids = torch.stack([x["input_ids"] for x in batch])
|
||||
masks = torch.stack([x["attention_mask"] for x in batch])
|
||||
target_ids = torch.stack([x["decoder_input_ids"] for x in batch])
|
||||
tgt_pad_token_id = (
|
||||
self.tokenizer.generator.pad_token_id
|
||||
if isinstance(self.tokenizer, RagTokenizer)
|
||||
else self.tokenizer.pad_token_id
|
||||
)
|
||||
src_pad_token_id = (
|
||||
self.tokenizer.question_encoder.pad_token_id
|
||||
if isinstance(self.tokenizer, RagTokenizer)
|
||||
else self.tokenizer.pad_token_id
|
||||
)
|
||||
y = trim_batch(target_ids, tgt_pad_token_id)
|
||||
source_ids, source_mask = trim_batch(input_ids, src_pad_token_id, attention_mask=masks)
|
||||
batch = {
|
||||
"input_ids": source_ids,
|
||||
"attention_mask": source_mask,
|
||||
"decoder_input_ids": y,
|
||||
}
|
||||
return batch
|
||||
|
||||
def make_sortish_sampler(self, batch_size):
|
||||
return SortishSampler(self.src_lens, batch_size)
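# Usage sketch (illustrative paths): `data_dir` is expected to contain
# `<type_path>.source` and `<type_path>.target` files with one example per line.
#
#   ds = Seq2SeqDataset(tokenizer, data_dir="data/nq", max_source_length=128, max_target_length=25)
#   loader = torch.utils.data.DataLoader(
#       ds, batch_size=8, collate_fn=ds.collate_fn, sampler=ds.make_sortish_sampler(8)
#   )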
|
||||
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def normalize_answer(s):
|
||||
"""Lower text and remove punctuation, articles and extra whitespace."""
|
||||
|
||||
def remove_articles(text):
|
||||
return re.sub(r"\b(a|an|the)\b", " ", text)
|
||||
|
||||
def white_space_fix(text):
|
||||
return " ".join(text.split())
|
||||
|
||||
def remove_punc(text):
|
||||
exclude = set(string.punctuation)
|
||||
return "".join(ch for ch in text if ch not in exclude)
|
||||
|
||||
def lower(text):
|
||||
return text.lower()
|
||||
|
||||
return white_space_fix(remove_articles(remove_punc(lower(s))))
|
||||
|
||||
|
||||
def f1_score(prediction, ground_truth):
|
||||
prediction_tokens = normalize_answer(prediction).split()
|
||||
ground_truth_tokens = normalize_answer(ground_truth).split()
|
||||
common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
|
||||
num_same = sum(common.values())
|
||||
if num_same == 0:
|
||||
return 0
|
||||
precision = 1.0 * num_same / len(prediction_tokens)
|
||||
recall = 1.0 * num_same / len(ground_truth_tokens)
|
||||
f1 = (2 * precision * recall) / (precision + recall)
|
||||
return f1
|
||||
|
||||
|
||||
def exact_match_score(prediction, ground_truth):
|
||||
return normalize_answer(prediction) == normalize_answer(ground_truth)
|
||||
|
||||
|
||||
def calculate_exact_match(output_lns: List[str], reference_lns: List[str]) -> Dict:
|
||||
assert len(output_lns) == len(reference_lns)
|
||||
em = 0
|
||||
for hypo, pred in zip(output_lns, reference_lns):
|
||||
em += exact_match_score(hypo, pred)
|
||||
if len(output_lns) > 0:
|
||||
em /= len(output_lns)
|
||||
return {"em": em}
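# Minimal sanity check (illustrative values):
#   calculate_exact_match(["Paris", "london"], ["Paris", "Berlin"])  # -> {"em": 0.5}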
|
||||
|
||||
|
||||
def is_rag_model(model_prefix):
|
||||
return model_prefix.startswith("rag")
|
||||
|
||||
|
||||
def set_extra_model_params(extra_params, hparams, config):
|
||||
equivalent_param = {p: p for p in extra_params}
|
||||
# T5 models don't have `dropout` param, they have `dropout_rate` instead
|
||||
equivalent_param["dropout"] = "dropout_rate"
|
||||
for p in extra_params:
|
||||
if getattr(hparams, p, None):
|
||||
if not hasattr(config, p) and not hasattr(config, equivalent_param[p]):
|
||||
logger.info("config doesn't have a `{}` attribute".format(p))
|
||||
delattr(hparams, p)
|
||||
continue
|
||||
set_p = p if hasattr(config, p) else equivalent_param[p]
|
||||
setattr(config, set_p, getattr(hparams, p))
|
||||
delattr(hparams, p)
|
||||
return hparams, config
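# Example (illustrative): with `hparams.dropout == 0.1` and a T5 config, which only
# exposes `dropout_rate`, calling
#   set_extra_model_params(("dropout",), hparams, config)
# copies the value into `config.dropout_rate` and removes `dropout` from `hparams`.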
|
|
@ -8,7 +8,7 @@ tensorflow_datasets
|
|||
pytorch-lightning==0.8.5
|
||||
matplotlib
|
||||
git-python==1.0.3
|
||||
faiss
|
||||
faiss-cpu
|
||||
streamlit
|
||||
elasticsearch
|
||||
pandas
|
||||
|
|
|
@ -10,7 +10,7 @@ known_third_party =
|
|||
datasets
|
||||
elasticsearch
|
||||
fairseq
|
||||
faiss
|
||||
faiss-cpu
|
||||
fastprogress
|
||||
fire
|
||||
fugashi
|
||||
|
|
setup.py
|
@ -89,7 +89,8 @@ extras["onnxruntime"] = ["onnxruntime>=1.4.0", "onnxruntime-tools>=1.4.2"]
|
|||
extras["serving"] = ["pydantic", "uvicorn", "fastapi", "starlette"]
|
||||
extras["all"] = extras["serving"] + ["tensorflow", "torch"]
|
||||
|
||||
extras["testing"] = ["pytest", "pytest-xdist", "timeout-decorator", "psutil", "parameterized"]
|
||||
extras["retrieval"] = ["faiss-cpu", "datasets"]
|
||||
extras["testing"] = ["pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil"] + extras["retrieval"]
|
||||
# sphinx-rtd-theme==0.5.0 introduced big changes in the style.
|
||||
extras["docs"] = ["recommonmark", "sphinx", "sphinx-markdown-tables", "sphinx-rtd-theme==0.4.3", "sphinx-copybutton"]
|
||||
extras["quality"] = ["black >= 20.8b1", "isort >= 5", "flake8 >= 3.8.3"]
|
||||
|
|
|
@ -42,6 +42,7 @@ from .configuration_mmbt import MMBTConfig
|
|||
from .configuration_mobilebert import MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileBertConfig
|
||||
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
|
||||
from .configuration_pegasus import PegasusConfig
|
||||
from .configuration_rag import RagConfig
|
||||
from .configuration_reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig
|
||||
from .configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig
|
||||
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
|
||||
|
@ -86,6 +87,7 @@ from .file_utils import (
|
|||
cached_path,
|
||||
is_apex_available,
|
||||
is_datasets_available,
|
||||
is_faiss_available,
|
||||
is_psutil_available,
|
||||
is_py3nvml_available,
|
||||
is_tf_available,
|
||||
|
@ -140,6 +142,9 @@ from .pipelines import (
|
|||
pipeline,
|
||||
)
|
||||
|
||||
# Retriever
|
||||
from .retrieval_rag import RagRetriever
|
||||
|
||||
# Tokenizers
|
||||
from .tokenization_albert import AlbertTokenizer
|
||||
from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
|
||||
|
@ -172,6 +177,7 @@ from .tokenization_mobilebert import MobileBertTokenizer, MobileBertTokenizerFas
|
|||
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
|
||||
from .tokenization_pegasus import PegasusTokenizer
|
||||
from .tokenization_phobert import PhobertTokenizer
|
||||
from .tokenization_rag import RagTokenizer
|
||||
from .tokenization_reformer import ReformerTokenizer
|
||||
from .tokenization_retribert import RetriBertTokenizer, RetriBertTokenizerFast
|
||||
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
|
||||
|
@ -416,6 +422,7 @@ if is_torch_available():
|
|||
load_tf_weights_in_openai_gpt,
|
||||
)
|
||||
from .modeling_pegasus import PegasusForConditionalGeneration
|
||||
from .modeling_rag import RagModel, RagSequenceForGeneration, RagTokenForGeneration
|
||||
from .modeling_reformer import (
|
||||
REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
ReformerAttention,
|
||||
|
|
|
@ -24,6 +24,7 @@ from .configuration_bert_generation import BertGenerationConfig
|
|||
from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
|
||||
from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
|
||||
from .configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig
|
||||
from .configuration_dpr import DPR_PRETRAINED_CONFIG_ARCHIVE_MAP, DPRConfig
|
||||
from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig
|
||||
from .configuration_encoder_decoder import EncoderDecoderConfig
|
||||
from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig
|
||||
|
@ -38,6 +39,7 @@ from .configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfi
|
|||
from .configuration_mobilebert import MobileBertConfig
|
||||
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
|
||||
from .configuration_pegasus import PegasusConfig
|
||||
from .configuration_rag import RagConfig
|
||||
from .configuration_reformer import ReformerConfig
|
||||
from .configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig
|
||||
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
|
||||
|
@ -75,6 +77,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
|
|||
FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
DPR_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
]
|
||||
for key, value, in pretrained_map.items()
|
||||
)
|
||||
|
@ -110,7 +113,9 @@ CONFIG_MAPPING = OrderedDict(
|
|||
("encoder-decoder", EncoderDecoderConfig),
|
||||
("funnel", FunnelConfig),
|
||||
("lxmert", LxmertConfig),
|
||||
("dpr", DPRConfig),
|
||||
("layoutlm", LayoutLMConfig),
|
||||
("rag", RagConfig),
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -145,6 +150,8 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
|||
("funnel", "Funnel Transformer"),
|
||||
("lxmert", "LXMERT"),
|
||||
("layoutlm", "LayoutLM"),
|
||||
("dpr", "DPR"),
|
||||
("rag", "RAG"),
|
||||
]
|
||||
)
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# limitations under the License.
|
||||
""" DPR model configuration """
|
||||
|
||||
from .configuration_bert import BertConfig
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .utils import logging
|
||||
|
||||
|
||||
|
@ -27,7 +27,7 @@ DPR_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
|||
}
|
||||
|
||||
|
||||
class DPRConfig(BertConfig):
|
||||
class DPRConfig(PretrainedConfig):
|
||||
r"""
|
||||
:class:`~transformers.DPRConfig` is the configuration class to store the configuration of a
|
||||
`DPRModel`.
|
||||
|
@ -36,12 +36,73 @@ class DPRConfig(BertConfig):
|
|||
It is used to instantiate the components of the DPR model.
|
||||
|
||||
Args:
|
||||
projection_dim (:obj:`int`, optional, defaults to 0):
|
||||
vocab_size (:obj:`int`, `optional`, defaults to 30522):
|
||||
Vocabulary size of the DPR model. Defines the different tokens that
|
||||
can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.BertModel`.
|
||||
hidden_size (:obj:`int`, `optional`, defaults to 768):
|
||||
Dimensionality of the encoder layers and the pooler layer.
|
||||
num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (:obj:`int`, `optional`, defaults to 12):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
intermediate_size (:obj:`int`, `optional`, defaults to 3072):
|
||||
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||
hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler.
|
||||
If string, :obj:`"gelu"`, :obj:`"relu"`, :obj:`"swish"` and :obj:`"gelu_new"` are supported.
|
||||
hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
|
||||
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
|
||||
The dropout ratio for the attention probabilities.
|
||||
max_position_embeddings (:obj:`int`, `optional`, defaults to 512):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
|
||||
type_vocab_size (:obj:`int`, `optional`, defaults to 2):
|
||||
The vocabulary size of the `token_type_ids` passed into :class:`~transformers.BertModel`.
|
||||
initializer_range (:obj:`float`, `optional`, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
|
||||
The epsilon used by the layer normalization layers.
|
||||
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
|
||||
projection_dim (:obj:`int`, `optional`, defaults to 0):
|
||||
Dimension of the projection for the context and question encoders.
|
||||
If it is set to zero (default), then no projection is done.
|
||||
"""
|
||||
model_type = "dpr"
|
||||
|
||||
def __init__(self, projection_dim: int = 0, **kwargs): # projection of the encoders, 0 for no projection
|
||||
super().__init__(**kwargs)
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=30522,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-12,
|
||||
pad_token_id=0,
|
||||
gradient_checkpointing=False,
|
||||
projection_dim: int = 0,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(pad_token_id=pad_token_id, **kwargs)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.initializer_range = initializer_range
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.gradient_checkpointing = gradient_checkpointing
|
||||
self.projection_dim = projection_dim
|
||||
|
|
|
@ -0,0 +1,176 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020, The RAG Authors and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" RAG model configuration """
|
||||
|
||||
import copy
|
||||
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .file_utils import add_start_docstrings
|
||||
|
||||
|
||||
RAG_CONFIG_DOC = r"""
|
||||
:class:`~transformers.RagConfig` stores the configuration of a `RagModel`.
|
||||
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used
|
||||
to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig`
|
||||
for more information.
|
||||
|
||||
Args:
|
||||
title_sep (:obj:`str`, `optional`, defaults to ``" / "``):
|
||||
Separator inserted between the title and the text of the retrieved document when calling :class:`~transformers.RagRetriever`.
|
||||
doc_sep (:obj:`str`, `optional`, defaults to ``" // "``):
|
||||
Separator inserted between the text of the retrieved document and the original input when calling :class:`~transformers.RagRetriever`.
|
||||
n_docs (:obj:`int`, `optional`, defaults to 5):
|
||||
Number of documents to retrieve.
|
||||
max_combined_length (:obj:`int`, `optional`, defaults to 300):
|
||||
Max length of contextualized input returned by :meth:`~transformers.RagRetriever.__call__`.
|
||||
retrieval_vector_size (:obj:`int`, `optional`, defaults to 768):
|
||||
Dimensionality of the document embeddings indexed by :class:`~transformers.RagRetriever`.
|
||||
retrieval_batch_size (:obj:`int`, `optional`, defaults to 8):
|
||||
Retrieval batch size, defined as the number of queries issued concurrently to the faiss index encapsulated by :class:`~transformers.RagRetriever`.
|
||||
dataset (:obj:`str`, `optional`, defaults to :obj:`"wiki_dpr"`):
|
||||
A dataset identifier of the indexed dataset on HuggingFace AWS bucket (list all available datasets and ids using :obj:`datasets.list_datasets()`).
|
||||
dataset_split (:obj:`str`, `optional`, defaults to :obj:`train`)
|
||||
Which split of the ``dataset`` to load.
|
||||
index_name (:obj:`str`, `optional`, defaults to :obj:`compressed`)
|
||||
The index_name of the index associated with the :obj:`dataset`. One can choose between :obj:`legacy`, :obj:`exact` and :obj:`compressed`.
|
||||
index_path (:obj:`str`, `optional`)
|
||||
The path to the serialized faiss index on disk.
|
||||
passages_path (:obj:`str`, `optional`):
|
||||
A path to text passages compatible with the faiss index. Required if using :class:`~transformers.retrieval_rag.LegacyIndex`
|
||||
use_dummy_dataset (:obj:`bool`, `optional`, defaults to ``False``)
|
||||
Whether to load a "dummy" variant of the dataset specified by :obj:`dataset`.
|
||||
label_smoothing (:obj:`float`, `optional`, defaults to 0.0):
|
||||
Only relevant if ``return_loss`` is set to :obj:`True`. Controls the ``epsilon`` parameter value for label smoothing in the loss calculation.
|
||||
If set to ``0.0``, no label smoothing is performed.
|
||||
do_marginalize (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
If :obj:`True`, the logits are marginalized over all documents
|
||||
by making use of ``torch.nn.functional.log_softmax``.
|
||||
reduce_loss (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
If :obj:`True`, the NLL loss is reduced using the ``torch.Tensor.sum`` operation.
|
||||
do_deduplication (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Controls whether we want to deduplicate the generations from different context documents for a given input.
|
||||
Has to be set to :obj:`False` if used while training with distributed backend.
|
||||
exclude_bos_score (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
If :obj:`True`, the score of the BOS token is disregarded when computing
|
||||
the loss.
|
||||
output_retrieved(:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
If set to ``True``, :obj:`retrieved_doc_embeds`, :obj:`retrieved_doc_ids`, :obj:`context_input_ids` and :obj:`context_attention_mask` are returned. See returned tensors for more detail.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(RAG_CONFIG_DOC)
|
||||
class RagConfig(PretrainedConfig):
|
||||
model_type = "rag"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=None,
|
||||
is_encoder_decoder=True,
|
||||
prefix=None,
|
||||
bos_token_id=None,
|
||||
pad_token_id=None,
|
||||
eos_token_id=None,
|
||||
decoder_start_token_id=None,
|
||||
title_sep=" / ",
|
||||
doc_sep=" // ",
|
||||
n_docs=5,
|
||||
max_combined_length=300,
|
||||
retrieval_vector_size=768,
|
||||
retrieval_batch_size=8,
|
||||
dataset="wiki_dpr",
|
||||
dataset_split="train",
|
||||
index_name="compressed",
|
||||
index_path=None,
|
||||
passages_path=None,
|
||||
use_dummy_dataset=False,
|
||||
reduce_loss=False,
|
||||
label_smoothing=0.0,
|
||||
do_deduplication=True,
|
||||
exclude_bos_score=False,
|
||||
do_marginalize=False,
|
||||
output_retrieved=False,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
bos_token_id=bos_token_id,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
decoder_start_token_id=decoder_start_token_id,
|
||||
is_encoder_decoder=is_encoder_decoder,
|
||||
prefix=prefix,
|
||||
vocab_size=vocab_size,
|
||||
**kwargs,
|
||||
)
|
||||
assert (
|
||||
"question_encoder" in kwargs and "generator" in kwargs
|
||||
), "Config has to be initialized with question_encoder and generator config"
|
||||
question_encoder_config = kwargs.pop("question_encoder")
|
||||
question_encoder_model_type = question_encoder_config.pop("model_type")
|
||||
decoder_config = kwargs.pop("generator")
|
||||
decoder_model_type = decoder_config.pop("model_type")
|
||||
|
||||
from .configuration_auto import AutoConfig
|
||||
|
||||
self.question_encoder = AutoConfig.for_model(question_encoder_model_type, **question_encoder_config)
|
||||
self.generator = AutoConfig.for_model(decoder_model_type, **decoder_config)
|
||||
|
||||
self.reduce_loss = reduce_loss
|
||||
self.label_smoothing = label_smoothing
|
||||
self.exclude_bos_score = exclude_bos_score
|
||||
self.do_marginalize = do_marginalize
|
||||
|
||||
self.title_sep = title_sep
|
||||
self.doc_sep = doc_sep
|
||||
self.n_docs = n_docs
|
||||
self.max_combined_length = max_combined_length
|
||||
|
||||
self.dataset = dataset
|
||||
self.dataset_split = dataset_split
|
||||
self.index_name = index_name
|
||||
|
||||
self.retrieval_vector_size = retrieval_vector_size
|
||||
self.retrieval_batch_size = retrieval_batch_size
|
||||
self.passages_path = passages_path
|
||||
self.index_path = index_path
|
||||
self.use_dummy_dataset = use_dummy_dataset
|
||||
|
||||
self.output_retrieved = output_retrieved
|
||||
|
||||
self.do_deduplication = do_deduplication
|
||||
|
||||
@classmethod
|
||||
def from_question_encoder_generator_configs(
|
||||
cls, question_encoder_config: PretrainedConfig, generator_config: PretrainedConfig, **kwargs
|
||||
) -> PretrainedConfig:
|
||||
r"""
|
||||
Instantiate a :class:`~transformers.RagConfig` (or a derived class) from a question encoder model configuration and a generator model configuration.
|
||||
|
||||
Returns:
|
||||
:class:`RagConfig`: An instance of a configuration object
|
||||
"""
|
||||
return cls(question_encoder=question_encoder_config.to_dict(), generator=generator_config.to_dict(), **kwargs)
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Serializes this instance to a Python dictionary. Overrides the default :meth:`~transformers.PretrainedConfig.to_dict`.
|
||||
|
||||
Returns:
|
||||
:obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
|
||||
"""
|
||||
output = copy.deepcopy(self.__dict__)
|
||||
output["question_encoder"] = self.question_encoder.to_dict()
|
||||
output["generator"] = self.generator.to_dict()
|
||||
output["model_type"] = self.__class__.model_type
|
||||
return output
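# Usage sketch (keyword values are illustrative; the import path assumes these configs
# are exported at the top level of transformers, as this PR's __init__ changes indicate):
#
#   from transformers import BartConfig, DPRConfig, RagConfig
#   rag_config = RagConfig.from_question_encoder_generator_configs(
#       DPRConfig(), BartConfig(), n_docs=5, index_name="exact"
#   )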
|
|
@ -69,6 +69,7 @@ try:
|
|||
import datasets # noqa: F401
|
||||
|
||||
_datasets_available = True
|
||||
logger.debug(f"Successfully imported datasets version {datasets.__version__}")
|
||||
|
||||
except ImportError:
|
||||
_datasets_available = False
|
||||
|
@ -119,6 +120,16 @@ try:
|
|||
except ImportError:
|
||||
_has_apex = False
|
||||
|
||||
|
||||
try:
|
||||
import faiss # noqa: F401
|
||||
|
||||
_faiss_available = True
|
||||
logger.debug(f"Successfully imported faiss version {faiss.__version__}")
|
||||
except ImportError:
|
||||
_faiss_available = False
|
||||
|
||||
|
||||
default_cache_path = os.path.join(torch_cache_home, "transformers")
|
||||
|
||||
|
||||
|
@ -175,6 +186,10 @@ def is_apex_available():
|
|||
return _has_apex
|
||||
|
||||
|
||||
def is_faiss_available():
|
||||
return _faiss_available
|
||||
|
||||
|
||||
def add_start_docstrings(*docstr):
|
||||
def docstring_decorator(fn):
|
||||
fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
|
||||
|
|
|
@ -27,6 +27,7 @@ from .configuration_auto import (
|
|||
CamembertConfig,
|
||||
CTRLConfig,
|
||||
DistilBertConfig,
|
||||
DPRConfig,
|
||||
ElectraConfig,
|
||||
EncoderDecoderConfig,
|
||||
FlaubertConfig,
|
||||
|
@ -97,6 +98,7 @@ from .modeling_distilbert import (
|
|||
DistilBertForTokenClassification,
|
||||
DistilBertModel,
|
||||
)
|
||||
from .modeling_dpr import DPRQuestionEncoder
|
||||
from .modeling_electra import (
|
||||
ElectraForMaskedLM,
|
||||
ElectraForMultipleChoice,
|
||||
|
@ -148,6 +150,11 @@ from .modeling_mobilebert import (
|
|||
)
|
||||
from .modeling_openai import OpenAIGPTLMHeadModel, OpenAIGPTModel
|
||||
from .modeling_pegasus import PegasusForConditionalGeneration
|
||||
from .modeling_rag import ( # noqa: F401 - need to import all RagModels to be in globals() function
|
||||
RagModel,
|
||||
RagSequenceForGeneration,
|
||||
RagTokenForGeneration,
|
||||
)
|
||||
from .modeling_reformer import (
|
||||
ReformerForMaskedLM,
|
||||
ReformerForQuestionAnswering,
|
||||
|
@ -224,6 +231,7 @@ MODEL_MAPPING = OrderedDict(
|
|||
(FunnelConfig, FunnelModel),
|
||||
(LxmertConfig, LxmertModel),
|
||||
(BertGenerationConfig, BertGenerationEncoder),
|
||||
(DPRConfig, DPRQuestionEncoder),
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -412,7 +420,6 @@ MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
|
|||
]
|
||||
)
|
||||
|
||||
|
||||
AUTO_MODEL_PRETRAINED_DOCSTRING = r"""
|
||||
|
||||
The model class to instantiate is selected based on the :obj:`model_type` property of the config object
|
||||
|
|
|
@ -272,6 +272,7 @@ class DPRPretrainedContextEncoder(PreTrainedModel):
|
|||
config_class = DPRConfig
|
||||
load_tf_weights = None
|
||||
base_model_prefix = "ctx_encoder"
|
||||
authorized_missing_keys = [r"position_ids"]
|
||||
|
||||
def init_weights(self):
|
||||
self.ctx_encoder.init_weights()
|
||||
|
@ -285,6 +286,7 @@ class DPRPretrainedQuestionEncoder(PreTrainedModel):
|
|||
config_class = DPRConfig
|
||||
load_tf_weights = None
|
||||
base_model_prefix = "question_encoder"
|
||||
authorized_missing_keys = [r"position_ids"]
|
||||
|
||||
def init_weights(self):
|
||||
self.question_encoder.init_weights()
|
||||
|
@ -298,6 +300,7 @@ class DPRPretrainedReader(PreTrainedModel):
|
|||
config_class = DPRConfig
|
||||
load_tf_weights = None
|
||||
base_model_prefix = "span_predictor"
|
||||
authorized_missing_keys = [r"position_ids"]
|
||||
|
||||
def init_weights(self):
|
||||
self.span_predictor.encoder.init_weights()
|
||||
|
|
The diff for this file is not shown because of its large size.
|
@ -0,0 +1,470 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020, The RAG Authors and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""RAG Retriever model implementation."""
|
||||
|
||||
import os
|
||||
import pickle
|
||||
import time
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .configuration_rag import RagConfig
|
||||
from .file_utils import cached_path, is_datasets_available, is_faiss_available, is_remote_url
|
||||
from .tokenization_rag import RagTokenizer
|
||||
from .tokenization_utils_base import BatchEncoding
|
||||
from .utils import logging
|
||||
|
||||
|
||||
if is_datasets_available() and is_faiss_available():
|
||||
from datasets import load_dataset
|
||||
|
||||
import faiss
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
LEGACY_INDEX_PATH = "https://storage.googleapis.com/huggingface-nlp/datasets/wiki_dpr/"
|
||||
|
||||
|
||||
class Index:
|
||||
"""
|
||||
A base class for the Indices encapsulated by the :class:`~transformers.RagRetriever`.
|
||||
"""
|
||||
|
||||
def get_doc_dicts(self, doc_ids: np.ndarray) -> List[dict]:
|
||||
"""
|
||||
Returns a list of dictionaries, containing titles and text of the retrieved documents.
|
||||
|
||||
Args:
|
||||
doc_ids (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs)`):
|
||||
A tensor of document indices.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_top_docs(self, question_hidden_states: np.ndarray, n_docs=5) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
For each query in the batch, retrieves ``n_docs`` documents.
|
||||
|
||||
Args:
|
||||
question_hidden_states (:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size):
|
||||
An array of query vectors.
|
||||
n_docs (:obj:`int`):
|
||||
The number of docs retrieved per query.
|
||||
|
||||
Returns:
|
||||
:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs)`: A tensor of indices of retrieved documents.
|
||||
:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size)`: A tensor of vector representations of retrieved documents.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def is_initialized(self):
|
||||
"""
|
||||
Returns :obj:`True` if index is already initialized.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def init_index(self):
|
||||
"""
|
||||
A function responsible for loading the index into memory. Should be called only once per training run of a RAG model.
|
||||
E.g. if the model is trained on multiple GPUs in a distributed setup, only one of the workers will load the index.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class LegacyIndex:
|
||||
"""
|
||||
An index which can be deserialized from the files built using https://github.com/facebookresearch/DPR.
|
||||
We use default faiss index parameters as specified in that repository.
|
||||
|
||||
Args:
|
||||
vector_size (:obj:`int`):
|
||||
The dimension of indexed vectors.
|
||||
index_path (:obj:`str`):
|
||||
A path to a `directory` containing index files compatible with
|
||||
:class:`~transformers.retrieval_rag.LegacyIndex`
|
||||
"""
|
||||
|
||||
INDEX_FILENAME = "hf_bert_base.hnswSQ8_correct_phi_128.c_index"
|
||||
PASSAGE_FILENAME = "psgs_w100.tsv.pkl"
|
||||
|
||||
def __init__(self, vector_size, index_path):
|
||||
self.index_id_to_db_id = []
|
||||
self.index_path = index_path
|
||||
self.passages = self._load_passages()
|
||||
self.vector_size = vector_size
|
||||
self.index = None
|
||||
self._index_initialize = False
|
||||
|
||||
def _resolve_path(self, index_path, filename):
|
||||
assert os.path.isdir(index_path) or is_remote_url(index_path), "Please specify a valid ``index_path``."
|
||||
archive_file = os.path.join(index_path, filename)
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
resolved_archive_file = cached_path(archive_file)
|
||||
if resolved_archive_file is None:
|
||||
raise EnvironmentError
|
||||
except EnvironmentError:
|
||||
msg = (
|
||||
f"Can't load '{archive_file}'. Make sure that:\n\n"
|
||||
f"- '{index_path}' is a correct remote path to a directory containing a file named {filename}\n"
|
||||
f"- or '{index_path}' is the correct path to a directory containing a file named {filename}.\n\n"
|
||||
)
|
||||
raise EnvironmentError(msg)
|
||||
if resolved_archive_file == archive_file:
|
||||
logger.info("loading file {}".format(archive_file))
|
||||
else:
|
||||
logger.info("loading file {} from cache at {}".format(archive_file, resolved_archive_file))
|
||||
return resolved_archive_file
|
||||
|
||||
def _load_passages(self):
|
||||
logger.info("Loading passages from {}".format(self.index_path))
|
||||
passages_path = self._resolve_path(self.index_path, self.PASSAGE_FILENAME)
|
||||
with open(passages_path, "rb") as passages_file:
|
||||
passages = pickle.load(passages_file)
|
||||
return passages
|
||||
|
||||
def _deserialize_index(self):
|
||||
logger.info("Loading index from {}".format(self.index_path))
|
||||
resolved_index_path = self._resolve_path(self.index_path, self.INDEX_FILENAME + ".index.dpr")
|
||||
self.index = faiss.read_index(resolved_index_path)
|
||||
resolved_meta_path = self._resolve_path(self.index_path, self.INDEX_FILENAME + ".index_meta.dpr")
|
||||
with open(resolved_meta_path, "rb") as metadata_file:
|
||||
self.index_id_to_db_id = pickle.load(metadata_file)
|
||||
assert (
|
||||
len(self.index_id_to_db_id) == self.index.ntotal
|
||||
), "Deserialized index_id_to_db_id should match faiss index size"
|
||||
|
||||
def is_initialized(self):
|
||||
return self._index_initialize
|
||||
|
||||
def init_index(self):
|
||||
index = faiss.IndexHNSWFlat(self.vector_size + 1, 512)
|
||||
index.hnsw.efSearch = 128
|
||||
index.hnsw.efConstruction = 200
|
||||
self.index = index
|
||||
self._deserialize_index()
|
||||
self._index_initialize = True
|
||||
|
||||
def get_doc_dicts(self, doc_ids: np.array):
|
||||
doc_list = []
|
||||
for doc_ids_i in doc_ids:
|
||||
ids = [str(int(doc_id)) for doc_id in doc_ids_i]
|
||||
docs = [self.passages[doc_id] for doc_id in ids]
|
||||
doc_list.append(docs)
|
||||
doc_dicts = []
|
||||
for docs in doc_list:
|
||||
doc_dict = {}
|
||||
doc_dict["title"] = [doc[1] for doc in docs]
|
||||
doc_dict["text"] = [doc[0] for doc in docs]
|
||||
doc_dicts.append(doc_dict)
|
||||
return doc_dicts
|
||||
|
||||
def get_top_docs(self, question_hidden_states: np.ndarray, n_docs=5) -> Tuple[np.ndarray, np.ndarray]:
|
||||
aux_dim = np.zeros(len(question_hidden_states), dtype="float32").reshape(-1, 1)
|
||||
query_hnsw_vectors = np.hstack((question_hidden_states, aux_dim))
|
||||
_, docs_ids = self.index.search(query_hnsw_vectors, n_docs)
|
||||
vectors = [[self.index.reconstruct(int(doc_id))[:-1] for doc_id in doc_ids] for doc_ids in docs_ids]
|
||||
ids = [[int(self.index_id_to_db_id[doc_id]) for doc_id in doc_ids] for doc_ids in docs_ids]
|
||||
return np.array(ids), np.array(vectors)
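# Note on the extra dimension: the legacy DPR index stores vectors of size
# ``vector_size + 1`` so that maximum inner product search can be carried out with
# an HNSW index over L2 distance; queries are therefore padded with a zero auxiliary
# dimension above, and the auxiliary component is dropped again ([:-1]) when
# reconstructing document vectors.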
|
||||
|
||||
|
||||
class HFIndex:
|
||||
"""
|
||||
A wrapper around an instance of :class:`~datasets.Dataset`. If ``index_path`` is set to ``None``,
the pre-computed index available with the :class:`~datasets.arrow_dataset.Dataset` is loaded; otherwise, the index is loaded from the indicated path on disk.
|
||||
|
||||
Args:
|
||||
dataset (:obj:`str`, optional, defaults to ``wiki_dpr``):
|
||||
A dataset identifier of the indexed dataset on HuggingFace AWS bucket (list all available datasets and ids with ``datasets.list_datasets()``).
|
||||
dataset_split (:obj:`str`, optional, defaults to ``train``)
|
||||
Which split of the ``dataset`` to load.
|
||||
index_name (:obj:`str`, optional):
|
||||
The index_name of the index associated with the ``dataset``. The index loaded from ``index_path`` will be saved under this name.
|
||||
index_path (:obj:`str`, optional, defaults to ``None``)
|
||||
The path to the serialized faiss index on disk.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_name: str,
|
||||
dataset_split: str,
|
||||
index_name: str,
|
||||
index_path: Optional[str] = None,
|
||||
use_dummy_dataset=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.dataset_name = dataset_name
|
||||
self.dataset_split = dataset_split
|
||||
self.index_name = index_name
|
||||
self.index_path = index_path
|
||||
self.use_dummy_dataset = use_dummy_dataset
|
||||
self._index_initialize = False
|
||||
|
||||
logger.info("Loading passages from {}".format(self.dataset_name))
|
||||
self.dataset = load_dataset(
|
||||
self.dataset_name, with_index=False, split=self.dataset_split, dummy=self.use_dummy_dataset
|
||||
)
|
||||
|
||||
def is_initialized(self):
|
||||
return self._index_initialize
|
||||
|
||||
def init_index(self):
|
||||
if self.index_path is not None:
|
||||
logger.info("Loading index from {}".format(self.index_path))
|
||||
self.dataset.load_faiss_index(index_name=self.index_name, file=self.index_path)
|
||||
else:
|
||||
logger.info("Loading index from {}".format(self.dataset_name + " with index name " + self.index_name))
|
||||
self.dataset = load_dataset(
|
||||
self.dataset_name,
|
||||
with_embeddings=True,
|
||||
with_index=True,
|
||||
split=self.dataset_split,
|
||||
index_name=self.index_name,
|
||||
dummy=self.use_dummy_dataset,
|
||||
)
|
||||
self._index_initialize = True
|
||||
|
||||
def get_doc_dicts(self, doc_ids: np.ndarray) -> List[dict]:
|
||||
return [self.dataset[doc_ids[i].tolist()] for i in range(doc_ids.shape[0])]
|
||||
|
||||
def get_top_docs(self, question_hidden_states: np.ndarray, n_docs=5) -> Tuple[np.ndarray, np.ndarray]:
|
||||
_, docs = self.dataset.get_nearest_examples_batch("embeddings", question_hidden_states, n_docs)
|
||||
ids = [[int(i) for i in doc["id"]] for doc in docs]
|
||||
vectors = [doc["embeddings"] for doc in docs]
|
||||
return np.array(ids), np.array(vectors) # shapes (batch_size, n_docs) and (batch_size, n_docs, d)
|
||||
|
||||
|
||||
class RagRetriever:
|
||||
"""
|
||||
Retriever used to get documents from vector queries.
|
||||
It retrieves the document embeddings as well as the document contents, and it formats them to be used with a RagModel.
|
||||
|
||||
Args:
|
||||
config (:class:`~transformers.RagConfig`):
|
||||
The configuration of the RAG model this Retriever is used with. Contains parameters indicating which ``Index`` to build.
|
||||
question_encoder_tokenizer (:class:`~transformers.PreTrainedTokenizer`):
|
||||
The tokenizer that was used to tokenize the question.
|
||||
It is used to decode the question so that it can be re-encoded with the ``generator_tokenizer`` together with the retrieved documents.
|
||||
generator_tokenizer (:class:`~transformers.PreTrainedTokenizer`):
|
||||
The tokenizer used for the generator part of the RagModel.
|
||||
"""
|
||||
|
||||
_init_retrieval = True
|
||||
|
||||
def __init__(self, config, question_encoder_tokenizer, generator_tokenizer):
|
||||
super().__init__()
|
||||
self.index = (
|
||||
LegacyIndex(
|
||||
config.retrieval_vector_size,
|
||||
config.index_path or LEGACY_INDEX_PATH,
|
||||
)
|
||||
if config.index_name == "legacy"
|
||||
else HFIndex(
|
||||
config.dataset, config.dataset_split, config.index_name, config.index_path, config.use_dummy_dataset
|
||||
)
|
||||
)
|
||||
self.generator_tokenizer = generator_tokenizer
|
||||
self.question_encoder_tokenizer = question_encoder_tokenizer
|
||||
|
||||
self.n_docs = config.n_docs
|
||||
self.batch_size = config.retrieval_batch_size
|
||||
|
||||
self.config = config
|
||||
if self._init_retrieval:
|
||||
self.init_retrieval()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, retriever_name_or_path, **kwargs):
|
||||
config = RagConfig.from_pretrained(retriever_name_or_path, **kwargs)
|
||||
rag_tokenizer = RagTokenizer.from_pretrained(retriever_name_or_path, config=config)
|
||||
question_encoder_tokenizer = rag_tokenizer.question_encoder
|
||||
generator_tokenizer = rag_tokenizer.generator
|
||||
return cls(
|
||||
config, question_encoder_tokenizer=question_encoder_tokenizer, generator_tokenizer=generator_tokenizer
|
||||
)
|
||||
|
||||
def save_pretrained(self, save_directory):
|
||||
self.config.save_pretrained(save_directory)
|
||||
rag_tokenizer = RagTokenizer(
|
||||
question_encoder_tokenizer=self.question_encoder_tokenizer,
|
||||
generator_tokenizer=self.generator_tokenizer,
|
||||
)
|
||||
rag_tokenizer.save_pretrained(save_directory)
|
||||
|
||||
def init_retrieval(self):
|
||||
"""
|
||||
Retriever initialization function. It loads the index into memory.
|
||||
"""
|
||||
|
||||
logger.info("initializing retrieval")
|
||||
self.index.init_index()
|
||||
|
||||
def postprocess_docs(self, docs, input_strings, prefix, n_docs, return_tensors=None):
|
||||
r"""
|
||||
Postprocessing retrieved ``docs`` and combining them with ``input_strings``.
|
||||
|
||||
Args:
|
||||
docs (:obj:`dict`):
|
||||
Retrieved documents.
|
||||
input_strings (:obj:`str`):
|
||||
Input strings decoded from :obj:`question_input_ids` by the ``question_encoder_tokenizer``.
|
||||
prefix (:obj:`str`):
|
||||
Prefix added at the beginning of each input, typically used with T5-based models.
|
||||
|
||||
Return:
|
||||
:obj:`tuple(tensors)`:
|
||||
a tuple consisting of two elements: contextualized ``input_ids`` and a compatible ``attention_mask``.
|
||||
"""
|
||||
|
||||
def cat_input_and_doc(doc_title, doc_text, input_string, prefix):
|
||||
# TODO(Patrick): if we train more RAG models, I want to put the input first to take advantage of effortless truncation
|
||||
# TODO(piktus): better handling of truncation
|
||||
if doc_title.startswith('"'):
|
||||
doc_title = doc_title[1:]
|
||||
if doc_title.endswith('"'):
|
||||
doc_title = doc_title[:-1]
|
||||
if prefix is None:
|
||||
prefix = ""
|
||||
out = (prefix + doc_title + self.config.title_sep + doc_text + self.config.doc_sep + input_string).replace(
    "  ", " "
)
|
||||
return out
|
||||
|
||||
rag_input_strings = [
|
||||
cat_input_and_doc(
|
||||
docs[i]["title"][j],
|
||||
docs[i]["text"][j],
|
||||
input_strings[i],
|
||||
prefix,
|
||||
)
|
||||
for i in range(len(docs))
|
||||
for j in range(n_docs)
|
||||
]
|
||||
|
||||
contextualized_inputs = self.generator_tokenizer.batch_encode_plus(
|
||||
rag_input_strings,
|
||||
max_length=self.config.max_combined_length,
|
||||
return_tensors=return_tensors,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
)
|
||||
|
||||
return contextualized_inputs["input_ids"], contextualized_inputs["attention_mask"]
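# Each contextualized string has the form (per retrieved document, before tokenization):
#   prefix + doc_title + config.title_sep + doc_text + config.doc_sep + input_string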
|
||||
|
||||
def _chunk_tensor(self, t: Iterable, chunk_size: int) -> List[Iterable]:
|
||||
return [t[i : i + chunk_size] for i in range(0, len(t), chunk_size)]
|
||||
|
||||
def _main_retrieve(self, question_hidden_states: np.ndarray, n_docs: int) -> Tuple[np.ndarray, np.ndarray]:
|
||||
question_hidden_states_batched = self._chunk_tensor(question_hidden_states, self.batch_size)
|
||||
ids_batched = []
|
||||
vectors_batched = []
|
||||
for question_hidden_states in question_hidden_states_batched:
|
||||
start_time = time.time()
|
||||
ids, vectors = self.index.get_top_docs(question_hidden_states, n_docs)
|
||||
logger.debug(
|
||||
"index search time: {} sec, batch size {}".format(
|
||||
time.time() - start_time, question_hidden_states.shape
|
||||
)
|
||||
)
|
||||
ids_batched.extend(ids)
|
||||
vectors_batched.extend(vectors)
|
||||
return np.array(ids_batched), np.array(
|
||||
vectors_batched
|
||||
) # shapes (batch_size, n_docs) and (batch_size, n_docs, d)
|
||||
|
||||
def retrieve(self, question_hidden_states: np.ndarray, n_docs: int) -> Tuple[np.ndarray, List[dict]]:
|
||||
"""
|
||||
Retrieves documents for specified ``question_hidden_states``.
|
||||
|
||||
Args:
|
||||
question_hidden_states (:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size)`):
|
||||
A batch of query vectors to retrieve with.
|
||||
n_docs (:obj:`int`):
|
||||
The number of docs retrieved per query.
|
||||
|
||||
Return:
|
||||
retrieved_doc_embeds (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs, dim)`):
    The retrieval embeddings of the retrieved docs per query.
doc_ids (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs)`):
    The ids of the documents in the index.
doc_dicts (:obj:`List[dict]`):
    The retrieved document dictionaries (title and text) per query.
|
||||
"""
|
||||
|
||||
doc_ids, retrieved_doc_embeds = self._main_retrieve(question_hidden_states, n_docs)
|
||||
return retrieved_doc_embeds, doc_ids, self.index.get_doc_dicts(doc_ids)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
question_input_ids: List[List[int]],
|
||||
question_hidden_states: np.ndarray,
|
||||
prefix=None,
|
||||
n_docs=None,
|
||||
return_tensors=None,
|
||||
) -> BatchEncoding:
|
||||
"""
|
||||
Retrieves documents for specified :obj:`question_hidden_states`.
|
||||
|
||||
Args:
|
||||
question_input_ids (:obj:`List[List[int]]`): Batch of input ids.
|
||||
question_hidden_states (:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size)`:
|
||||
A batch of query vectors to retrieve with.
|
||||
prefix (:obj:`str`, `optional`):
|
||||
The prefix used by the generator's tokenizer.
|
||||
n_docs (:obj:`int`, `optional`):
|
||||
The number of docs retrieved per query.
|
||||
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to "pt"):
|
||||
If set, will return tensors instead of list of python integers. Acceptable values are:
|
||||
|
||||
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
|
||||
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
|
||||
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
|
||||
|
||||
Output:
|
||||
:class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:
|
||||
|
||||
- **context_input_ids** -- List of token ids to be fed to a model.
|
||||
|
||||
`What are input IDs? <../glossary.html#input-ids>`__
|
||||
- **context_attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
||||
:obj:`return_attention_mask=True` or if `"attention_mask"` is in :obj:`self.model_input_names`).
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
- **retrieved_doc_embeds** -- List of embeddings of the retrieved documents
|
||||
- **doc_ids** -- List of ids of the retrieved documents
|
||||
"""
|
||||
|
||||
n_docs = n_docs if n_docs is not None else self.n_docs
|
||||
prefix = prefix if prefix is not None else self.config.generator.prefix
|
||||
retrieved_doc_embeds, doc_ids, docs = self.retrieve(question_hidden_states, n_docs)
|
||||
|
||||
input_strings = self.question_encoder_tokenizer.batch_decode(question_input_ids, skip_special_tokens=True)
|
||||
context_input_ids, context_attention_mask = self.postprocess_docs(
|
||||
docs, input_strings, prefix, n_docs, return_tensors=return_tensors
|
||||
)
|
||||
|
||||
return BatchEncoding(
|
||||
{
|
||||
"context_input_ids": context_input_ids,
|
||||
"context_attention_mask": context_attention_mask,
|
||||
"retrieved_doc_embeds": retrieved_doc_embeds,
|
||||
"doc_ids": doc_ids,
|
||||
},
|
||||
tensor_type=return_tensors,
|
||||
)
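# Usage sketch (placeholder path and illustrative inputs): given question hidden
# states produced by a DPR question encoder, the retriever returns contextualized
# inputs for the generator.
#
#   retriever = RagRetriever.from_pretrained("path/to/saved/retriever")  # placeholder
#   out = retriever(
#       question_input_ids=input_ids.tolist(),
#       question_hidden_states=question_hidden_states.numpy(),
#       return_tensors="pt",
#   )
#   context_input_ids = out["context_input_ids"]
#   retrieved_doc_embeds = out["retrieved_doc_embeds"]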
|
|
@ -10,7 +10,7 @@ from distutils.util import strtobool
|
|||
from io import StringIO
|
||||
from pathlib import Path
|
||||
|
||||
from .file_utils import _tf_available, _torch_available, _torch_tpu_available
|
||||
from .file_utils import _datasets_available, _faiss_available, _tf_available, _torch_available, _torch_tpu_available
|
||||
|
||||
|
||||
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
|
||||
|
@@ -161,6 +161,21 @@ def require_torch_and_cuda(test_case):
|
|||
return test_case
|
||||
|
||||
|
||||
def require_datasets(test_case):
|
||||
"""Decorator marking a test that requires datasets."""
|
||||
|
||||
if not _datasets_available:
|
||||
test_case = unittest.skip("test requires Datasets")(test_case)
|
||||
return test_case
|
||||
|
||||
|
||||
def require_faiss(test_case):
|
||||
"""Decorator marking a test that requires faiss."""
|
||||
if not _faiss_available:
|
||||
test_case = unittest.skip("test requires Faiss")(test_case)
|
||||
return test_case
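A short sketch of how these decorators compose on a test case; the class and test names below are illustrative only.

import unittest

from transformers.testing_utils import require_datasets, require_faiss


@require_faiss
@require_datasets
class MyRetrievalTest(unittest.TestCase):  # hypothetical test class
    def test_index_is_built(self):
        ...  # runs only when both datasets and faiss are installed; skipped otherwise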
|
||||
|
||||
|
||||
def get_tests_dir():
|
||||
"""
|
||||
returns the full path to the `tests` dir, so that the tests can be invoked from anywhere
|
||||
|
|
|
@@ -26,6 +26,7 @@ from .configuration_auto import (
|
|||
CamembertConfig,
|
||||
CTRLConfig,
|
||||
DistilBertConfig,
|
||||
DPRConfig,
|
||||
ElectraConfig,
|
||||
EncoderDecoderConfig,
|
||||
FlaubertConfig,
|
||||
|
@@ -40,6 +41,7 @@ from .configuration_auto import (
|
|||
MobileBertConfig,
|
||||
OpenAIGPTConfig,
|
||||
PegasusConfig,
|
||||
RagConfig,
|
||||
ReformerConfig,
|
||||
RetriBertConfig,
|
||||
RobertaConfig,
|
||||
|
@@ -60,6 +62,7 @@ from .tokenization_bertweet import BertweetTokenizer
|
|||
from .tokenization_camembert import CamembertTokenizer
|
||||
from .tokenization_ctrl import CTRLTokenizer
|
||||
from .tokenization_distilbert import DistilBertTokenizer, DistilBertTokenizerFast
|
||||
from .tokenization_dpr import DPRQuestionEncoderTokenizer, DPRQuestionEncoderTokenizerFast
|
||||
from .tokenization_electra import ElectraTokenizer, ElectraTokenizerFast
|
||||
from .tokenization_flaubert import FlaubertTokenizer
|
||||
from .tokenization_fsmt import FSMTTokenizer
|
||||
|
@@ -74,6 +77,7 @@ from .tokenization_mobilebert import MobileBertTokenizer, MobileBertTokenizerFast
|
|||
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
|
||||
from .tokenization_pegasus import PegasusTokenizer
|
||||
from .tokenization_phobert import PhobertTokenizer
|
||||
from .tokenization_rag import RagTokenizer
|
||||
from .tokenization_reformer import ReformerTokenizer
|
||||
from .tokenization_retribert import RetriBertTokenizer, RetriBertTokenizerFast
|
||||
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
|
||||
|
@@ -110,6 +114,7 @@ TOKENIZER_MAPPING = OrderedDict(
|
|||
(FunnelConfig, (FunnelTokenizer, FunnelTokenizerFast)),
|
||||
(LxmertConfig, (LxmertTokenizer, LxmertTokenizerFast)),
|
||||
(LayoutLMConfig, (LayoutLMTokenizer, LayoutLMTokenizerFast)),
|
||||
(DPRConfig, (DPRQuestionEncoderTokenizer, DPRQuestionEncoderTokenizerFast)),
|
||||
(BertConfig, (BertTokenizer, BertTokenizerFast)),
|
||||
(OpenAIGPTConfig, (OpenAIGPTTokenizer, OpenAIGPTTokenizerFast)),
|
||||
(GPT2Config, (GPT2Tokenizer, GPT2TokenizerFast)),
|
||||
|
@@ -121,6 +126,7 @@ TOKENIZER_MAPPING = OrderedDict(
|
|||
(FSMTConfig, (FSMTTokenizer, None)),
|
||||
(BertGenerationConfig, (BertGenerationTokenizer, None)),
|
||||
(LayoutLMConfig, (LayoutLMTokenizer, None)),
|
||||
(RagConfig, (RagTokenizer, None)),
|
||||
]
|
||||
)
|
||||
|
||||
|
|
|
@@ -0,0 +1,160 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020, The RAG Authors and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tokenization classes for RAG."""
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from .configuration_rag import RagConfig
|
||||
from .tokenization_utils_base import BatchEncoding
|
||||
from .utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class RagTokenizer:
|
||||
def __init__(self, question_encoder, generator):
|
||||
self.question_encoder = question_encoder
|
||||
self.generator = generator
|
||||
|
||||
def save_pretrained(self, save_directory):
|
||||
if os.path.isfile(save_directory):
|
||||
raise ValueError("Provided path ({}) should be a directory, not a file".format(save_directory))
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
question_encoder_path = os.path.join(save_directory, "question_encoder_tokenizer")
|
||||
generator_path = os.path.join(save_directory, "generator_tokenizer")
|
||||
self.question_encoder.save_pretrained(question_encoder_path)
|
||||
self.generator.save_pretrained(generator_path)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
||||
# dynamically import AutoTokenizer
|
||||
from .tokenization_auto import AutoTokenizer
|
||||
|
||||
config = kwargs.pop("config", None)
|
||||
|
||||
if config is None:
|
||||
config = RagConfig.from_pretrained(pretrained_model_name_or_path)
|
||||
|
||||
question_encoder_path = os.path.join(pretrained_model_name_or_path, "question_encoder_tokenizer")
|
||||
generator_path = os.path.join(pretrained_model_name_or_path, "generator_tokenizer")
|
||||
question_encoder = AutoTokenizer.from_pretrained(question_encoder_path, config=config.question_encoder)
|
||||
generator = AutoTokenizer.from_pretrained(generator_path, config=config.generator)
|
||||
return cls(question_encoder=question_encoder, generator=generator)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.question_encoder(*args, **kwargs)
|
||||
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
return self.generator.batch_decode(*args, **kwargs)
|
||||
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
padding: str = "longest",
|
||||
return_tensors: str = "np",
|
||||
truncation=True,
|
||||
**kwargs,
|
||||
) -> BatchEncoding:
|
||||
r"""
|
||||
|
||||
Prepare a batch that can be passed directly to an instance of :class:`~transformers.RagModel`.
|
||||
|
||||
Args:
|
||||
src_texts (:obj:`List[str]`):
|
||||
List of documents to summarize or source language texts.
|
||||
tgt_texts (:obj:`List[str]`, `optional`):
|
||||
List of summaries or target language texts.
|
||||
max_length (:obj:`int`, `optional`):
|
||||
Controls the maximum length for encoder inputs (documents to summarize or source language texts).
|
||||
If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum
|
||||
length is required by one of the truncation/padding parameters. If the model has no specific maximum
|
||||
input length (like XLNet), truncation/padding to a maximum length will be deactivated.
|
||||
max_target_length (:obj:`int`, `optional`):
|
||||
Controls the maximum length of decoder inputs (target language texts or summaries).
|
||||
If left unset or set to :obj:`None`, this will use the max_length value.
|
||||
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`"longest"`):
|
||||
Activates and controls padding. Accepts the following values:
|
||||
|
||||
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
|
||||
single sequence is provided).
|
||||
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
|
||||
maximum acceptable input length for the model if that argument is not provided.
|
||||
* :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
|
||||
different lengths).
|
||||
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to :obj:`"np"`):
|
||||
If set, will return tensors instead of list of python integers. Acceptable values are:
|
||||
|
||||
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
|
||||
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
|
||||
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
|
||||
truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
|
||||
Activates and controls truncation. Accepts the following values:
|
||||
|
||||
* :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument
|
||||
:obj:`max_length` or to the maximum acceptable input length for the model if that argument is not
|
||||
provided. This will truncate token by token, removing a token from the longest sequence in the pair
|
||||
if a pair of sequences (or a batch of pairs) is provided.
|
||||
* :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to
|
||||
the maximum acceptable input length for the model if that argument is not provided. This will only
|
||||
truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
|
||||
* :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
|
||||
to the maximum acceptable input length for the model if that argument is not provided. This will only
|
||||
truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
|
||||
* :obj:`False` or :obj:`'do_not_truncate'`: No truncation (i.e., can output batch with
|
||||
sequence lengths greater than the model maximum admissible input size).
|
||||
**kwargs:
|
||||
Additional keyword arguments passed along to :obj:`self.__call__`.
|
||||
|
||||
Returns:
|
||||
:class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:
|
||||
|
||||
- **input_ids** -- List of token ids to be fed to the encoder.
|
||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
|
||||
- **labels** -- List of token ids for tgt_texts
|
||||
|
||||
The full set of keys ``[input_ids, attention_mask, labels]``
|
||||
will only be returned if :obj:`tgt_texts` is passed. Otherwise, only :obj:`input_ids` and :obj:`attention_mask` will be returned.
|
||||
"""
|
||||
if max_length is None:
|
||||
max_length = self.question_encoder.model_max_length
|
||||
model_inputs: BatchEncoding = self.question_encoder(
|
||||
src_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
max_length=max_length,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)
|
||||
if tgt_texts is None:
|
||||
return model_inputs
|
||||
# Process tgt_texts
|
||||
if max_target_length is None:
|
||||
max_target_length = self.generator.model_max_length
|
||||
labels = self.generator(
|
||||
tgt_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
max_length=max_target_length,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)["input_ids"]
|
||||
model_inputs["labels"] = labels
|
||||
return model_inputs
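As a rough usage sketch of the tokenizer above: the checkpoint directory below is hypothetical and is expected to contain the question_encoder_tokenizer/ and generator_tokenizer/ subfolders written by save_pretrained, plus a RAG config.

from transformers.tokenization_rag import RagTokenizer

tokenizer = RagTokenizer.from_pretrained("./rag_checkpoint")  # hypothetical local path
batch = tokenizer.prepare_seq2seq_batch(
    src_texts=["who sings does he love me with reba"],
    tgt_texts=["Linda Davis"],
    return_tensors="pt",
)
# batch["input_ids"] / batch["attention_mask"] come from the question encoder tokenizer,
# batch["labels"] from the generator tokenizer.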
|
|
@@ -0,0 +1,885 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020, The RAG Authors and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.file_utils import cached_property, is_datasets_available, is_faiss_available, is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
from transformers.tokenization_bart import BartTokenizer
|
||||
from transformers.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
|
||||
from transformers.tokenization_dpr import DPRQuestionEncoderTokenizer
|
||||
from transformers.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
|
||||
from transformers.tokenization_t5 import T5Tokenizer
|
||||
|
||||
from .test_modeling_bart import ModelTester as BartModelTester
|
||||
from .test_modeling_dpr import DPRModelTester
|
||||
from .test_modeling_t5 import T5ModelTester
|
||||
|
||||
|
||||
TOLERANCE = 1e-3
|
||||
|
||||
T5_SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
|
||||
|
||||
if is_torch_available() and is_datasets_available() and is_faiss_available():
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
|
||||
import faiss
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModel,
|
||||
AutoModelForSeq2SeqLM,
|
||||
RagConfig,
|
||||
RagModel,
|
||||
RagRetriever,
|
||||
RagSequenceForGeneration,
|
||||
RagTokenForGeneration,
|
||||
)
|
||||
from transformers.modeling_outputs import BaseModelOutput
|
||||
|
||||
|
||||
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
|
||||
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
|
||||
if a is None and b is None:
|
||||
return True
|
||||
try:
|
||||
if torch.allclose(a, b, atol=atol):
|
||||
return True
|
||||
raise
|
||||
except Exception:
|
||||
msg = "{} != {}".format(a, b)
|
||||
if prefix:
|
||||
msg = prefix + ": " + msg
|
||||
raise AssertionError(msg)
|
||||
|
||||
|
||||
def require_retrieval(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires a set of dependencies necessary to perform retrieval with
|
||||
:class:`~transformers.RagRetriever`.
|
||||
|
||||
These tests are skipped when the respective libraries are not installed.
|
||||
|
||||
"""
|
||||
if not (is_torch_available() and is_datasets_available() and is_faiss_available()):
|
||||
test_case = unittest.skip("test requires PyTorch")(test_case)
|
||||
return test_case
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_retrieval
|
||||
class RagTestMixin:
|
||||
|
||||
all_model_classes = (
|
||||
(RagModel, RagTokenForGeneration, RagSequenceForGeneration)
|
||||
if is_torch_available() and is_datasets_available() and is_faiss_available()
|
||||
else ()
|
||||
)
|
||||
|
||||
retrieval_vector_size = 32
|
||||
n_docs = 2
|
||||
max_combined_length = 16
|
||||
|
||||
def setUp(self):
|
||||
self.tmpdirname = tempfile.mkdtemp()
|
||||
|
||||
# DPR tok
|
||||
vocab_tokens = [
|
||||
"[UNK]",
|
||||
"[CLS]",
|
||||
"[SEP]",
|
||||
"[PAD]",
|
||||
"[MASK]",
|
||||
"want",
|
||||
"##want",
|
||||
"##ed",
|
||||
"wa",
|
||||
"un",
|
||||
"runn",
|
||||
"##ing",
|
||||
",",
|
||||
"low",
|
||||
"lowest",
|
||||
]
|
||||
dpr_tokenizer_path = os.path.join(self.tmpdirname, "dpr_tokenizer")
|
||||
os.makedirs(dpr_tokenizer_path, exist_ok=True)
|
||||
self.vocab_file = os.path.join(dpr_tokenizer_path, DPR_VOCAB_FILES_NAMES["vocab_file"])
|
||||
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||
|
||||
# BART tok
|
||||
vocab = [
|
||||
"l",
|
||||
"o",
|
||||
"w",
|
||||
"e",
|
||||
"r",
|
||||
"s",
|
||||
"t",
|
||||
"i",
|
||||
"d",
|
||||
"n",
|
||||
"\u0120",
|
||||
"\u0120l",
|
||||
"\u0120n",
|
||||
"\u0120lo",
|
||||
"\u0120low",
|
||||
"er",
|
||||
"\u0120lowest",
|
||||
"\u0120newer",
|
||||
"\u0120wider",
|
||||
"<unk>",
|
||||
]
|
||||
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
||||
merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
|
||||
self.special_tokens_map = {"unk_token": "<unk>"}
|
||||
|
||||
bart_tokenizer_path = os.path.join(self.tmpdirname, "bart_tokenizer")
|
||||
os.makedirs(bart_tokenizer_path, exist_ok=True)
|
||||
self.vocab_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["vocab_file"])
|
||||
self.merges_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["merges_file"])
|
||||
with open(self.vocab_file, "w", encoding="utf-8") as fp:
|
||||
fp.write(json.dumps(vocab_tokens) + "\n")
|
||||
with open(self.merges_file, "w", encoding="utf-8") as fp:
|
||||
fp.write("\n".join(merges))
|
||||
|
||||
t5_tokenizer = T5Tokenizer(T5_SAMPLE_VOCAB)
|
||||
t5_tokenizer_path = os.path.join(self.tmpdirname, "t5_tokenizer")
|
||||
t5_tokenizer.save_pretrained(t5_tokenizer_path)
|
||||
|
||||
@cached_property
|
||||
def dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer:
|
||||
return DPRQuestionEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer"))
|
||||
|
||||
@cached_property
|
||||
def bart_tokenizer(self) -> BartTokenizer:
|
||||
return BartTokenizer.from_pretrained(os.path.join(self.tmpdirname, "bart_tokenizer"))
|
||||
|
||||
@cached_property
|
||||
def t5_tokenizer(self) -> BartTokenizer:
|
||||
return T5Tokenizer.from_pretrained(os.path.join(self.tmpdirname, "t5_tokenizer"))
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmpdirname)
|
||||
|
||||
def get_retriever(self, config):
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
"id": ["0", "1"],
|
||||
"text": ["foo", "bar"],
|
||||
"title": ["Foo", "Bar"],
|
||||
"embeddings": [np.ones(self.retrieval_vector_size), 2 * np.ones(self.retrieval_vector_size)],
|
||||
}
|
||||
)
|
||||
dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)
|
||||
tokenizer = self.bart_tokenizer if config.generator.model_type == "bart" else self.t5_tokenizer
|
||||
with patch("transformers.retrieval_rag.load_dataset") as mock_load_dataset:
|
||||
mock_load_dataset.return_value = dataset
|
||||
retriever = RagRetriever(
|
||||
config,
|
||||
question_encoder_tokenizer=self.dpr_tokenizer,
|
||||
generator_tokenizer=tokenizer,
|
||||
)
|
||||
return retriever
|
||||
|
||||
def check_model_with_retriever(
|
||||
self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
|
||||
):
|
||||
self.assertIsNotNone(config.question_encoder)
|
||||
self.assertIsNotNone(config.generator)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config, retriever=self.get_retriever(config)).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
self.assertTrue(model.config.is_encoder_decoder)
|
||||
|
||||
outputs = model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
)
|
||||
|
||||
# logits
|
||||
self.assertEqual(
|
||||
outputs.logits.shape,
|
||||
(self.n_docs * decoder_input_ids.shape[0], decoder_input_ids.shape[1], config.generator.vocab_size),
|
||||
)
|
||||
# generator encoder last hidden states
|
||||
self.assertEqual(
|
||||
outputs.generator_enc_last_hidden_state.shape,
|
||||
(self.n_docs * decoder_input_ids.shape[0], self.max_combined_length, config.generator.hidden_size),
|
||||
)
|
||||
# doc scores
|
||||
self.assertEqual(outputs.doc_scores.shape, (input_ids.shape[0], self.n_docs))
|
||||
|
||||
def check_model_generate(
|
||||
self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
|
||||
):
|
||||
self.assertIsNotNone(config.question_encoder)
|
||||
self.assertIsNotNone(config.generator)
|
||||
|
||||
for model_class in self.all_model_classes[1:]:
|
||||
model = model_class(config, retriever=self.get_retriever(config)).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
self.assertTrue(model.config.is_encoder_decoder)
|
||||
|
||||
outputs = model.generate(
|
||||
input_ids=input_ids,
|
||||
num_beams=2,
|
||||
num_return_sequences=2,
|
||||
decoder_start_token_id=config.generator.eos_token_id,
|
||||
)
|
||||
|
||||
self.assertIsNotNone(outputs)
|
||||
|
||||
def check_model_without_retriever(
|
||||
self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
|
||||
):
|
||||
self.assertIsNotNone(config.question_encoder)
|
||||
self.assertIsNotNone(config.generator)
|
||||
|
||||
retriever = self.get_retriever(config)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config).to(torch_device)
|
||||
model.eval()
|
||||
self.assertTrue(model.config.is_encoder_decoder)
|
||||
|
||||
question_hidden_states = model.question_encoder(input_ids, attention_mask=attention_mask)[0]
|
||||
|
||||
out = retriever(
|
||||
input_ids,
|
||||
question_hidden_states.cpu().detach().to(torch.float32).numpy(),
|
||||
prefix=config.generator.prefix,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
context_input_ids, context_attention_mask, retrieved_doc_embeds = (
|
||||
out["context_input_ids"],
|
||||
out["context_attention_mask"],
|
||||
out["retrieved_doc_embeds"],
|
||||
)
|
||||
|
||||
# cast
|
||||
retrieved_doc_embeds = retrieved_doc_embeds.to(question_hidden_states)
|
||||
context_input_ids = context_input_ids.to(input_ids)
|
||||
context_attention_mask = context_attention_mask.to(input_ids)
|
||||
|
||||
# compute doc_scores
|
||||
doc_scores = torch.bmm(question_hidden_states.unsqueeze(1), retrieved_doc_embeds.transpose(1, 2)).squeeze(
|
||||
1
|
||||
)
|
||||
|
||||
outputs = model(
|
||||
context_input_ids=context_input_ids,
|
||||
context_attention_mask=context_attention_mask,
|
||||
doc_scores=doc_scores,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
)
|
||||
|
||||
# logits
|
||||
self.assertEqual(
|
||||
outputs.logits.shape,
|
||||
(self.n_docs * decoder_input_ids.shape[0], decoder_input_ids.shape[1], config.generator.vocab_size),
|
||||
)
|
||||
# generator encoder last hidden states
|
||||
self.assertEqual(
|
||||
outputs.generator_enc_last_hidden_state.shape,
|
||||
(self.n_docs * decoder_input_ids.shape[0], self.max_combined_length, config.generator.hidden_size),
|
||||
)
|
||||
# doc scores
|
||||
self.assertEqual(outputs.doc_scores.shape, (input_ids.shape[0], self.n_docs))
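The doc-score computation used in check_model_without_retriever above is a batched inner product between the question embedding and each retrieved document embedding; a standalone sketch with illustrative shapes:

import torch

batch_size, n_docs, dim = 2, 5, 768
question_hidden_states = torch.randn(batch_size, dim)
retrieved_doc_embeds = torch.randn(batch_size, n_docs, dim)

# (batch, 1, dim) x (batch, dim, n_docs) -> (batch, 1, n_docs) -> (batch, n_docs)
doc_scores = torch.bmm(
    question_hidden_states.unsqueeze(1), retrieved_doc_embeds.transpose(1, 2)
).squeeze(1)
assert doc_scores.shape == (batch_size, n_docs)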
|
||||
|
||||
def check_model_with_encoder_outputs(
|
||||
self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
|
||||
):
|
||||
self.assertIsNotNone(config.question_encoder)
|
||||
self.assertIsNotNone(config.generator)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config, retriever=self.get_retriever(config)).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
self.assertTrue(model.config.is_encoder_decoder)
|
||||
|
||||
outputs = model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
)
|
||||
|
||||
encoder_outputs = BaseModelOutput(outputs.generator_enc_last_hidden_state)
|
||||
|
||||
# run only generator
|
||||
outputs = model(
|
||||
encoder_outputs=encoder_outputs,
|
||||
doc_scores=outputs.doc_scores,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
)
|
||||
|
||||
# logits
|
||||
self.assertEqual(
|
||||
outputs.logits.shape,
|
||||
(self.n_docs * decoder_input_ids.shape[0], decoder_input_ids.shape[1], config.generator.vocab_size),
|
||||
)
|
||||
# generator encoder last hidden states
|
||||
self.assertEqual(
|
||||
outputs.generator_enc_last_hidden_state.shape,
|
||||
(self.n_docs * decoder_input_ids.shape[0], self.max_combined_length, config.generator.hidden_size),
|
||||
)
|
||||
# doc scores
|
||||
self.assertEqual(outputs.doc_scores.shape, (input_ids.shape[0], self.n_docs))
|
||||
|
||||
def test_model_with_retriever(self):
|
||||
inputs_dict = self.config_and_inputs
|
||||
self.check_model_with_retriever(**inputs_dict)
|
||||
|
||||
def test_model_without_retriever(self):
|
||||
inputs_dict = self.config_and_inputs
|
||||
self.check_model_without_retriever(**inputs_dict)
|
||||
|
||||
def test_model_with_encoder_outputs(self):
|
||||
inputs_dict = self.config_and_inputs
|
||||
self.check_model_with_encoder_outputs(**inputs_dict)
|
||||
|
||||
def test_model_generate(self):
|
||||
inputs_dict = self.config_and_inputs
|
||||
self.check_model_generate(**inputs_dict)
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_retrieval
|
||||
class RagDPRBartTest(RagTestMixin, unittest.TestCase):
|
||||
@cached_property
|
||||
def config_and_inputs(self):
|
||||
question_encoder_tester = DPRModelTester(self)
|
||||
dpr_config_and_inputs = question_encoder_tester.prepare_config_and_inputs()
|
||||
generator_tester = BartModelTester(self)
|
||||
bart_config_and_inputs = generator_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
(question_encoder_config, input_ids, _, input_mask, _, _, _) = dpr_config_and_inputs
|
||||
(generator_config, bart_inputs_dict) = bart_config_and_inputs
|
||||
decoder_input_ids, decoder_attention_mask = bart_inputs_dict["input_ids"], bart_inputs_dict["attention_mask"]
|
||||
|
||||
config = RagConfig.from_question_encoder_generator_configs(
|
||||
question_encoder_config,
|
||||
generator_config,
|
||||
n_docs=self.n_docs,
|
||||
retrieval_vector_size=self.retrieval_vector_size,
|
||||
max_combined_length=self.max_combined_length,
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
return {
|
||||
"config": config,
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": input_mask,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
}
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_retrieval
|
||||
class RagDPRT5Test(RagTestMixin, unittest.TestCase):
|
||||
@cached_property
|
||||
def config_and_inputs(self):
|
||||
question_encoder_tester = DPRModelTester(self)
|
||||
dpr_config_and_inputs = question_encoder_tester.prepare_config_and_inputs()
|
||||
generator_tester = T5ModelTester(self, vocab_size=1100, n_positions=30)
|
||||
t5_config_and_inputs = generator_tester.prepare_config_and_inputs()
|
||||
|
||||
(question_encoder_config, input_ids, _, input_mask, _, _, _) = dpr_config_and_inputs
|
||||
|
||||
(generator_config, _, decoder_input_ids, _, decoder_attention_mask, _) = t5_config_and_inputs
|
||||
config = RagConfig.from_question_encoder_generator_configs(
|
||||
question_encoder_config,
|
||||
generator_config,
|
||||
n_docs=self.n_docs,
|
||||
retrieval_vector_size=self.retrieval_vector_size,
|
||||
max_combined_length=self.max_combined_length,
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
return {
|
||||
"config": config,
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": input_mask,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
}
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_retrieval
|
||||
class RagModelIntegrationTests(unittest.TestCase):
|
||||
@cached_property
|
||||
def sequence_model(self):
|
||||
return (
|
||||
RagSequenceForGeneration.from_pretrained_question_encoder_generator(
|
||||
"facebook/dpr-question_encoder-single-nq-base", "facebook/bart-large-cnn"
|
||||
)
|
||||
.to(torch_device)
|
||||
.eval()
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def token_model(self):
|
||||
return (
|
||||
RagTokenForGeneration.from_pretrained_question_encoder_generator(
|
||||
"facebook/dpr-question_encoder-single-nq-base", "facebook/bart-large-cnn"
|
||||
)
|
||||
.to(torch_device)
|
||||
.eval()
|
||||
)
|
||||
|
||||
def get_rag_config(self):
|
||||
question_encoder_config = AutoConfig.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
|
||||
generator_config = AutoConfig.from_pretrained("facebook/bart-large-cnn")
|
||||
return RagConfig.from_question_encoder_generator_configs(
|
||||
question_encoder_config,
|
||||
generator_config,
|
||||
bos_token_id=0,
|
||||
decoder_start_token_id=2,
|
||||
eos_token_id=2,
|
||||
is_encoder_decoder=True,
|
||||
pad_token_id=1,
|
||||
vocab_size=50264,
|
||||
title_sep=" / ",
|
||||
doc_sep=" // ",
|
||||
n_docs=5,
|
||||
max_combined_length=300,
|
||||
dataset="wiki_dpr",
|
||||
dataset_split="train",
|
||||
index_name="exact",
|
||||
index_path=None,
|
||||
use_dummy_dataset=True,
|
||||
retrieval_vector_size=768,
|
||||
retrieval_batch_size=8,
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_rag_sequence_inference(self):
|
||||
rag_config = self.get_rag_config()
|
||||
rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
|
||||
rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
|
||||
"facebook/dpr-question_encoder-single-nq-base"
|
||||
)
|
||||
rag_retriever = RagRetriever(
|
||||
rag_config,
|
||||
question_encoder_tokenizer=rag_question_encoder_tokenizer,
|
||||
generator_tokenizer=rag_decoder_tokenizer,
|
||||
)
|
||||
|
||||
rag_sequence = self.sequence_model
|
||||
rag_sequence.set_retriever(rag_retriever)
|
||||
|
||||
input_ids = rag_question_encoder_tokenizer(
|
||||
"who sings does he love me with reba", return_tensors="pt"
|
||||
).input_ids
|
||||
decoder_input_ids = rag_decoder_tokenizer("Linda Davis", return_tensors="pt").input_ids
|
||||
|
||||
input_ids = input_ids.to(torch_device)
|
||||
decoder_input_ids = decoder_input_ids.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
output = rag_sequence(
|
||||
input_ids,
|
||||
labels=decoder_input_ids,
|
||||
)
|
||||
|
||||
expected_shape = torch.Size([5, 5, 50264])
|
||||
self.assertEqual(output.logits.shape, expected_shape)
|
||||
|
||||
expected_doc_scores = torch.tensor([[75.0286, 74.4998, 74.0804, 74.0306, 73.9504]]).to(torch_device)
|
||||
_assert_tensors_equal(expected_doc_scores, output.doc_scores, atol=TOLERANCE)
|
||||
|
||||
expected_loss = torch.tensor([38.7446]).to(torch_device)
|
||||
_assert_tensors_equal(expected_loss, output.loss, atol=TOLERANCE)
|
||||
|
||||
@slow
|
||||
def test_rag_token_inference(self):
|
||||
rag_config = self.get_rag_config()
|
||||
rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
|
||||
rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
|
||||
"facebook/dpr-question_encoder-single-nq-base"
|
||||
)
|
||||
rag_retriever = RagRetriever(
|
||||
rag_config,
|
||||
question_encoder_tokenizer=rag_question_encoder_tokenizer,
|
||||
generator_tokenizer=rag_decoder_tokenizer,
|
||||
)
|
||||
|
||||
rag_token = self.token_model
|
||||
rag_token.set_retriever(rag_retriever)
|
||||
|
||||
input_ids = rag_question_encoder_tokenizer(
|
||||
"who sings does he love me with reba", return_tensors="pt"
|
||||
).input_ids
|
||||
decoder_input_ids = rag_decoder_tokenizer("Linda Davis", return_tensors="pt").input_ids
|
||||
|
||||
input_ids = input_ids.to(torch_device)
|
||||
decoder_input_ids = decoder_input_ids.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
output = rag_token(
|
||||
input_ids,
|
||||
labels=decoder_input_ids,
|
||||
)
|
||||
|
||||
expected_shape = torch.Size([5, 5, 50264])
|
||||
self.assertEqual(output.logits.shape, expected_shape)
|
||||
|
||||
expected_doc_scores = torch.tensor([[75.0286, 74.4998, 74.0804, 74.0306, 73.9504]]).to(torch_device)
|
||||
_assert_tensors_equal(expected_doc_scores, output.doc_scores, atol=TOLERANCE)
|
||||
|
||||
expected_loss = torch.tensor([38.7045]).to(torch_device)
|
||||
_assert_tensors_equal(expected_loss, output.loss, atol=TOLERANCE)
|
||||
|
||||
@slow
|
||||
def test_rag_token_generate_beam(self):
|
||||
rag_config = self.get_rag_config()
|
||||
rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
|
||||
rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
|
||||
"facebook/dpr-question_encoder-single-nq-base"
|
||||
)
|
||||
rag_retriever = RagRetriever(
|
||||
rag_config,
|
||||
question_encoder_tokenizer=rag_question_encoder_tokenizer,
|
||||
generator_tokenizer=rag_decoder_tokenizer,
|
||||
)
|
||||
|
||||
rag_token = self.token_model
|
||||
rag_token.set_retriever(rag_retriever)
|
||||
|
||||
input_ids = rag_question_encoder_tokenizer(
|
||||
"who sings does he love me with reba", return_tensors="pt"
|
||||
).input_ids
|
||||
|
||||
input_ids = input_ids.to(torch_device)
|
||||
|
||||
output_ids = rag_token.generate(
|
||||
input_ids,
|
||||
decoder_start_token_id=rag_token.generator.config.decoder_start_token_id,
|
||||
num_beams=2,
|
||||
num_return_sequences=2,
|
||||
)
|
||||
# sequence generate test
|
||||
output_text_1 = rag_decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
||||
output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)
|
||||
|
||||
# Expected outputs as given by model at integration time.
|
||||
EXPECTED_OUTPUT_TEXT_1 = "The songwriting credits are credited to ABBA"
|
||||
EXPECTED_OUTPUT_TEXT_2 = 'The songwriting credits are credited to "B'
|
||||
|
||||
self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
|
||||
self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)
|
||||
|
||||
@slow
|
||||
def test_rag_token_generate_batch(self):
|
||||
rag_config = self.get_rag_config()
|
||||
rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
|
||||
rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
|
||||
"facebook/dpr-question_encoder-single-nq-base"
|
||||
)
|
||||
rag_retriever = RagRetriever(
|
||||
rag_config,
|
||||
question_encoder_tokenizer=rag_question_encoder_tokenizer,
|
||||
generator_tokenizer=rag_decoder_tokenizer,
|
||||
)
|
||||
|
||||
rag_token = self.token_model
|
||||
rag_token.set_retriever(rag_retriever)
|
||||
|
||||
questions = [
|
||||
"who sings does he love me with reba",
|
||||
"how many pages is invisible man by ralph ellison",
|
||||
]
|
||||
input_ids = rag_question_encoder_tokenizer.batch_encode_plus(
|
||||
questions,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
).input_ids
|
||||
|
||||
input_ids = input_ids.to(torch_device)
|
||||
|
||||
output_ids = rag_token.generate(
|
||||
input_ids,
|
||||
decoder_start_token_id=rag_token.generator.config.decoder_start_token_id,
|
||||
num_beams=4,
|
||||
num_return_sequences=1,
|
||||
max_length=10,
|
||||
)
|
||||
|
||||
# sequence generate test
|
||||
output_text_1 = rag_decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
||||
output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)
|
||||
|
||||
# Expected outputs as given by model at integration time.
|
||||
EXPECTED_OUTPUT_TEXT_1 = '"People Need Love" is the'
|
||||
EXPECTED_OUTPUT_TEXT_2 = '"How many pages is invisible man'
|
||||
|
||||
self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
|
||||
self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)
|
||||
|
||||
@slow
|
||||
def test_rag_sequence_generate_batch(self):
|
||||
# IMPORTANT: This test fails on GPU, but passes on CPU -> beam search is very sensitive
|
||||
rag_config = self.get_rag_config()
|
||||
rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
|
||||
rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
|
||||
"facebook/dpr-question_encoder-single-nq-base"
|
||||
)
|
||||
rag_retriever = RagRetriever(
|
||||
rag_config,
|
||||
question_encoder_tokenizer=rag_question_encoder_tokenizer,
|
||||
generator_tokenizer=rag_decoder_tokenizer,
|
||||
)
|
||||
|
||||
rag_sequence = self.sequence_model
|
||||
rag_sequence.set_retriever(rag_retriever)
|
||||
|
||||
questions = [
|
||||
"who sings does he love me with reba",
|
||||
"how many pages is invisible man by ralph ellison",
|
||||
]
|
||||
input_ids = rag_question_encoder_tokenizer.batch_encode_plus(
|
||||
questions,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
).input_ids
|
||||
|
||||
input_ids = input_ids.to(torch_device)
|
||||
|
||||
output_ids = rag_sequence.generate(
|
||||
input_ids,
|
||||
decoder_start_token_id=rag_sequence.generator.config.decoder_start_token_id,
|
||||
num_beams=4,
|
||||
num_return_sequences=1,
|
||||
max_length=10,
|
||||
)
|
||||
|
||||
# sequence generate test
|
||||
output_text_1 = rag_decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
||||
output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)
|
||||
|
||||
# Expected outputs as given by model at integration time.
|
||||
EXPECTED_OUTPUT_TEXT_1 = '"I Know Him So Well"'
|
||||
EXPECTED_OUTPUT_TEXT_2 = '"Howl" chronicles the'
|
||||
|
||||
self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
|
||||
self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)
|
||||
|
||||
@slow
|
||||
def test_rag_sequence_generate_beam(self):
|
||||
rag_config = self.get_rag_config()
|
||||
rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
|
||||
rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
|
||||
"facebook/dpr-question_encoder-single-nq-base"
|
||||
)
|
||||
rag_retriever = RagRetriever(
|
||||
rag_config,
|
||||
question_encoder_tokenizer=rag_question_encoder_tokenizer,
|
||||
generator_tokenizer=rag_decoder_tokenizer,
|
||||
)
|
||||
|
||||
rag_token = self.sequence_model
|
||||
rag_token.set_retriever(rag_retriever)
|
||||
|
||||
input_ids = rag_question_encoder_tokenizer(
|
||||
"who sings does he love me with reba", return_tensors="pt"
|
||||
).input_ids
|
||||
|
||||
input_ids = input_ids.to(torch_device)
|
||||
|
||||
output_ids = rag_token.generate(
|
||||
input_ids,
|
||||
decoder_start_token_id=rag_token.generator.config.decoder_start_token_id,
|
||||
num_beams=2,
|
||||
num_return_sequences=2,
|
||||
)
|
||||
# sequence generate test
|
||||
output_text_1 = rag_decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
||||
output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)
|
||||
|
||||
# Expected outputs as given by model at integration time.
|
||||
EXPECTED_OUTPUT_TEXT_1 = """ ABBA / small label like Playboy Records did not have the distribution resources to meet the demand for the single from retailers and radio programmers. The foursome decided to record their first album together in late 1972, and sessions began on 26 September 1972. The women shared lead vocals on "Nina, Pretty Ballerina" that day."""
|
||||
EXPECTED_OUTPUT_TEXT_2 = """ ABBA / small label like Playboy Records did not have the distribution resources to meet the demand for the single from retailers and radio programmers. The foursome decided to record their first album together in late 1972, and sessions began on 26 September 1972. The women shared lead vocals on "Nina, Pretty Ballerina" (a top ten hit in Austria)"""
|
||||
|
||||
self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
|
||||
self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_retrieval
|
||||
class RagModelSaveLoadTests(unittest.TestCase):
|
||||
def get_rag_config(self):
|
||||
question_encoder_config = AutoConfig.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
|
||||
generator_config = AutoConfig.from_pretrained("facebook/bart-large-cnn")
|
||||
return RagConfig.from_question_encoder_generator_configs(
|
||||
question_encoder_config,
|
||||
generator_config,
|
||||
bos_token_id=0,
|
||||
decoder_start_token_id=2,
|
||||
eos_token_id=2,
|
||||
is_encoder_decoder=True,
|
||||
pad_token_id=1,
|
||||
vocab_size=50264,
|
||||
title_sep=" / ",
|
||||
doc_sep=" // ",
|
||||
n_docs=5,
|
||||
max_combined_length=300,
|
||||
dataset="wiki_dpr",
|
||||
dataset_split="train",
|
||||
index_name="exact",
|
||||
index_path=None,
|
||||
use_dummy_dataset=True,
|
||||
retrieval_vector_size=768,
|
||||
retrieval_batch_size=8,
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_rag_sequence_from_pretrained(self):
|
||||
rag_config = self.get_rag_config()
|
||||
rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
|
||||
rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
|
||||
"facebook/dpr-question_encoder-single-nq-base"
|
||||
)
|
||||
rag_retriever = RagRetriever(
|
||||
rag_config,
|
||||
question_encoder_tokenizer=rag_question_encoder_tokenizer,
|
||||
generator_tokenizer=rag_decoder_tokenizer,
|
||||
)
|
||||
|
||||
input_ids = rag_question_encoder_tokenizer(
|
||||
"who sings does he love me with reba", return_tensors="pt"
|
||||
).input_ids
|
||||
decoder_input_ids = rag_decoder_tokenizer("Linda Davis", return_tensors="pt").input_ids
|
||||
|
||||
input_ids = input_ids.to(torch_device)
|
||||
decoder_input_ids = decoder_input_ids.to(torch_device)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dirname:
|
||||
rag_sequence = RagSequenceForGeneration.from_pretrained_question_encoder_generator(
|
||||
"facebook/dpr-question_encoder-single-nq-base",
|
||||
"facebook/bart-large-cnn",
|
||||
retriever=rag_retriever,
|
||||
config=rag_config,
|
||||
).to(torch_device)
|
||||
# check that the from pretrained methods work
|
||||
rag_sequence.save_pretrained(tmp_dirname)
|
||||
rag_sequence.from_pretrained(tmp_dirname, retriever=rag_retriever)
|
||||
rag_sequence.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
output = rag_sequence(
|
||||
input_ids,
|
||||
labels=decoder_input_ids,
|
||||
)
|
||||
|
||||
loss_pretrained = output.loss
|
||||
del rag_sequence
|
||||
|
||||
question_encoder = AutoModel.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
|
||||
generator = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
|
||||
rag_sequence = RagSequenceForGeneration(
|
||||
config=rag_config, question_encoder=question_encoder, generator=generator, retriever=rag_retriever
|
||||
)
|
||||
rag_sequence.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
output = rag_sequence(
|
||||
input_ids,
|
||||
labels=decoder_input_ids,
|
||||
)
|
||||
|
||||
loss_init = output.loss
|
||||
|
||||
self.assertAlmostEqual(loss_pretrained.item(), loss_init.item(), places=4)
|
||||
|
||||
@slow
|
||||
def test_rag_token_from_pretrained(self):
|
||||
rag_config = self.get_rag_config()
|
||||
rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
|
||||
rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
|
||||
"facebook/dpr-question_encoder-single-nq-base"
|
||||
)
|
||||
rag_retriever = RagRetriever(
|
||||
rag_config,
|
||||
question_encoder_tokenizer=rag_question_encoder_tokenizer,
|
||||
generator_tokenizer=rag_decoder_tokenizer,
|
||||
)
|
||||
|
||||
input_ids = rag_question_encoder_tokenizer(
|
||||
"who sings does he love me with reba", return_tensors="pt"
|
||||
).input_ids
|
||||
decoder_input_ids = rag_decoder_tokenizer("Linda Davis", return_tensors="pt").input_ids
|
||||
|
||||
input_ids = input_ids.to(torch_device)
|
||||
decoder_input_ids = decoder_input_ids.to(torch_device)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dirname:
|
||||
rag_token = RagTokenForGeneration.from_pretrained_question_encoder_generator(
|
||||
"facebook/dpr-question_encoder-single-nq-base",
|
||||
"facebook/bart-large-cnn",
|
||||
retriever=rag_retriever,
|
||||
config=rag_config,
|
||||
).to(torch_device)
|
||||
# check that the from pretrained methods work
|
||||
rag_token.save_pretrained(tmp_dirname)
|
||||
rag_token.from_pretrained(tmp_dirname, retriever=rag_retriever)
|
||||
rag_token.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
output = rag_token(
|
||||
input_ids,
|
||||
labels=decoder_input_ids,
|
||||
)
|
||||
|
||||
loss_pretrained = output.loss
|
||||
del rag_token
|
||||
|
||||
question_encoder = AutoModel.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
|
||||
generator = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
|
||||
rag_token = RagTokenForGeneration(
|
||||
config=rag_config, question_encoder=question_encoder, generator=generator, retriever=rag_retriever
|
||||
)
|
||||
rag_token.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
output = rag_token(
|
||||
input_ids,
|
||||
labels=decoder_input_ids,
|
||||
)
|
||||
|
||||
loss_init = output.loss
|
||||
|
||||
self.assertAlmostEqual(loss_pretrained.item(), loss_init.item(), places=4)
|
|
@@ -35,28 +35,53 @@ if is_torch_available():
|
|||
|
||||
|
||||
class T5ModelTester:
|
||||
def __init__(self, parent):
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
vocab_size=99,
|
||||
n_positions=14,
|
||||
batch_size=13,
|
||||
encoder_seq_length=7,
|
||||
decoder_seq_length=9,
|
||||
# For common tests
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_attention_mask=True,
|
||||
use_labels=True,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
d_ff=37,
|
||||
relative_attention_num_buckets=8,
|
||||
dropout_rate=0.1,
|
||||
initializer_factor=0.002,
|
||||
eos_token_id=1,
|
||||
pad_token_id=0,
|
||||
decoder_start_token_id=0,
|
||||
scope=None,
|
||||
):
|
||||
|
||||
self.parent = parent
|
||||
self.batch_size = 13
|
||||
self.encoder_seq_length = 7
|
||||
self.decoder_seq_length = 9
|
||||
self.batch_size = batch_size
|
||||
self.encoder_seq_length = encoder_seq_length
|
||||
self.decoder_seq_length = decoder_seq_length
|
||||
# For common tests
|
||||
self.seq_length = self.decoder_seq_length
|
||||
self.is_training = True
|
||||
self.use_attention_mask = True
|
||||
self.use_labels = True
|
||||
self.vocab_size = 99
|
||||
self.n_positions = 14
|
||||
self.hidden_size = 32
|
||||
self.num_hidden_layers = 5
|
||||
self.num_attention_heads = 4
|
||||
self.d_ff = 37
|
||||
self.relative_attention_num_buckets = 8
|
||||
self.dropout_rate = 0.1
|
||||
self.initializer_factor = 0.002
|
||||
self.eos_token_id = 1
|
||||
self.pad_token_id = 0
|
||||
self.decoder_start_token_id = 0
|
||||
self.is_training = is_training
|
||||
self.use_attention_mask = use_attention_mask
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.n_positions = n_positions
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.d_ff = d_ff
|
||||
self.relative_attention_num_buckets = relative_attention_num_buckets
|
||||
self.dropout_rate = dropout_rate
|
||||
self.initializer_factor = initializer_factor
|
||||
self.eos_token_id = eos_token_id
|
||||
self.pad_token_id = pad_token_id
|
||||
self.decoder_start_token_id = decoder_start_token_id
|
||||
self.scope = None
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
|
|
|
@@ -0,0 +1,223 @@
|
|||
import json
|
||||
import os
|
||||
import pickle
|
||||
import shutil
|
||||
import tempfile
|
||||
from unittest import TestCase
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
from datasets import Dataset
|
||||
|
||||
import faiss
|
||||
from transformers.configuration_bart import BartConfig
|
||||
from transformers.configuration_dpr import DPRConfig
|
||||
from transformers.configuration_rag import RagConfig
|
||||
from transformers.retrieval_rag import RagRetriever
|
||||
from transformers.testing_utils import require_datasets, require_faiss, require_torch
|
||||
from transformers.tokenization_bart import BartTokenizer
|
||||
from transformers.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
|
||||
from transformers.tokenization_dpr import DPRQuestionEncoderTokenizer
|
||||
from transformers.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
|
||||
|
||||
|
||||
@require_faiss
|
||||
@require_datasets
|
||||
class RagRetrieverTest(TestCase):
|
||||
def setUp(self):
|
||||
self.tmpdirname = tempfile.mkdtemp()
|
||||
self.retrieval_vector_size = 8
|
||||
|
||||
# DPR tok
|
||||
vocab_tokens = [
|
||||
"[UNK]",
|
||||
"[CLS]",
|
||||
"[SEP]",
|
||||
"[PAD]",
|
||||
"[MASK]",
|
||||
"want",
|
||||
"##want",
|
||||
"##ed",
|
||||
"wa",
|
||||
"un",
|
||||
"runn",
|
||||
"##ing",
|
||||
",",
|
||||
"low",
|
||||
"lowest",
|
||||
]
|
||||
dpr_tokenizer_path = os.path.join(self.tmpdirname, "dpr_tokenizer")
|
||||
os.makedirs(dpr_tokenizer_path, exist_ok=True)
|
||||
self.vocab_file = os.path.join(dpr_tokenizer_path, DPR_VOCAB_FILES_NAMES["vocab_file"])
|
||||
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||
|
||||
# BART tok
|
||||
vocab = [
|
||||
"l",
|
||||
"o",
|
||||
"w",
|
||||
"e",
|
||||
"r",
|
||||
"s",
|
||||
"t",
|
||||
"i",
|
||||
"d",
|
||||
"n",
|
||||
"\u0120",
|
||||
"\u0120l",
|
||||
"\u0120n",
|
||||
"\u0120lo",
|
||||
"\u0120low",
|
||||
"er",
|
||||
"\u0120lowest",
|
||||
"\u0120newer",
|
||||
"\u0120wider",
|
||||
"<unk>",
|
||||
]
|
||||
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
||||
merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
|
||||
self.special_tokens_map = {"unk_token": "<unk>"}
|
||||
|
||||
bart_tokenizer_path = os.path.join(self.tmpdirname, "bart_tokenizer")
|
||||
os.makedirs(bart_tokenizer_path, exist_ok=True)
|
||||
self.vocab_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["vocab_file"])
|
||||
self.merges_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["merges_file"])
|
||||
with open(self.vocab_file, "w", encoding="utf-8") as fp:
|
||||
fp.write(json.dumps(vocab_tokens) + "\n")
|
||||
with open(self.merges_file, "w", encoding="utf-8") as fp:
|
||||
fp.write("\n".join(merges))
|
||||
|
||||
def get_dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer:
|
||||
return DPRQuestionEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer"))
|
||||
|
||||
def get_bart_tokenizer(self) -> BartTokenizer:
|
||||
return BartTokenizer.from_pretrained(os.path.join(self.tmpdirname, "bart_tokenizer"))
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmpdirname)
|
||||
|
||||
def get_dummy_hf_index_retriever(self):
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
"id": ["0", "1"],
|
||||
"text": ["foo", "bar"],
|
||||
"title": ["Foo", "Bar"],
|
||||
"embeddings": [np.ones(self.retrieval_vector_size), 2 * np.ones(self.retrieval_vector_size)],
|
||||
}
|
||||
)
|
||||
dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)
|
||||
config = RagConfig(
|
||||
retrieval_vector_size=self.retrieval_vector_size,
|
||||
question_encoder=DPRConfig().to_dict(),
|
||||
generator=BartConfig().to_dict(),
|
||||
)
|
||||
with patch("transformers.retrieval_rag.load_dataset") as mock_load_dataset:
|
||||
mock_load_dataset.return_value = dataset
|
||||
retriever = RagRetriever(
|
||||
config,
|
||||
question_encoder_tokenizer=self.get_dpr_tokenizer(),
|
||||
generator_tokenizer=self.get_bart_tokenizer(),
|
||||
)
|
||||
return retriever
|
||||
|
||||
def get_dummy_legacy_index_retriever(self):
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
"id": ["0", "1"],
|
||||
"text": ["foo", "bar"],
|
||||
"title": ["Foo", "Bar"],
|
||||
"embeddings": [np.ones(self.retrieval_vector_size + 1), 2 * np.ones(self.retrieval_vector_size + 1)],
|
||||
}
|
||||
)
|
||||
dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)
|
||||
|
||||
index_file_name = os.path.join(self.tmpdirname, "hf_bert_base.hnswSQ8_correct_phi_128.c_index")
|
||||
dataset.save_faiss_index("embeddings", index_file_name + ".index.dpr")
|
||||
pickle.dump(dataset["id"], open(index_file_name + ".index_meta.dpr", "wb"))
|
||||
|
||||
passages_file_name = os.path.join(self.tmpdirname, "psgs_w100.tsv.pkl")
|
||||
passages = {sample["id"]: [sample["text"], sample["title"]] for sample in dataset}
|
||||
pickle.dump(passages, open(passages_file_name, "wb"))
|
||||
|
||||
config = RagConfig(
|
||||
retrieval_vector_size=self.retrieval_vector_size,
|
||||
question_encoder=DPRConfig().to_dict(),
|
||||
generator=BartConfig().to_dict(),
|
||||
index_name="legacy",
|
||||
index_path=self.tmpdirname,
|
||||
passages_path=self.tmpdirname,
|
||||
)
|
||||
retriever = RagRetriever(
|
||||
config, question_encoder_tokenizer=self.get_dpr_tokenizer(), generator_tokenizer=self.get_bart_tokenizer()
|
||||
)
|
||||
return retriever
|
||||
|
||||
def test_hf_index_retriever_retrieve(self):
|
||||
n_docs = 1
|
||||
retriever = self.get_dummy_hf_index_retriever()
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
)
|
||||
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
|
||||
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
|
||||
self.assertEqual(len(doc_dicts), 2)
|
||||
self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
|
||||
self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
|
||||
self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc
|
||||
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
|
||||
self.assertListEqual(list(doc_ids), [1, 0])
|
||||
|
||||
def test_legacy_index_retriever_retrieve(self):
|
||||
n_docs = 1
|
||||
retriever = self.get_dummy_legacy_index_retriever()
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
)
|
||||
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
|
||||
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
|
||||
self.assertEqual(len(doc_dicts), 2)
|
||||
self.assertEqual(sorted(doc_dicts[0]), ["text", "title"])
|
||||
self.assertEqual(len(doc_dicts[0]["text"]), n_docs)
|
||||
self.assertEqual(doc_dicts[0]["text"][0], "bar") # max inner product is reached with second doc
|
||||
self.assertEqual(doc_dicts[1]["text"][0], "foo") # max inner product is reached with first doc
|
||||
self.assertListEqual(list(doc_ids), [1, 0])
|
||||
|
||||
@require_torch
|
||||
def test_hf_index_retriever_call(self):
|
||||
import torch
|
||||
|
||||
n_docs = 1
|
||||
retriever = self.get_dummy_hf_index_retriever()
|
||||
question_input_ids = [[5, 7], [10, 11]]
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
)
|
||||
out = retriever(question_input_ids, hidden_states, prefix=retriever.config.generator.prefix, n_docs=n_docs)
|
||||
context_input_ids, context_attention_mask, retrieved_doc_embeds = (
|
||||
out["context_input_ids"],
|
||||
out["context_attention_mask"],
|
||||
out["retrieved_doc_embeds"],
|
||||
)
|
||||
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
|
||||
self.assertIsInstance(context_input_ids, list)
|
||||
self.assertIsInstance(context_attention_mask, list)
|
||||
self.assertIsInstance(retrieved_doc_embeds, np.ndarray)
|
||||
|
||||
out = retriever(
|
||||
question_input_ids,
|
||||
hidden_states,
|
||||
prefix=retriever.config.generator.prefix,
|
||||
n_docs=n_docs,
|
||||
return_tensors="pt",
|
||||
)
|
||||
context_input_ids, context_attention_mask, retrieved_doc_embeds, doc_ids = ( # noqa: F841
|
||||
out["context_input_ids"],
|
||||
out["context_attention_mask"],
|
||||
out["retrieved_doc_embeds"],
|
||||
out["doc_ids"],
|
||||
)
|
||||
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
|
||||
self.assertIsInstance(context_input_ids, torch.Tensor)
|
||||
self.assertIsInstance(context_attention_mask, torch.Tensor)
|
||||
self.assertIsInstance(retrieved_doc_embeds, torch.Tensor)
|
|
@@ -0,0 +1,110 @@
|
|||
import json
import os
import shutil
import tempfile
from unittest import TestCase

from transformers.configuration_bart import BartConfig
from transformers.configuration_dpr import DPRConfig
from transformers.file_utils import is_datasets_available, is_faiss_available, is_torch_available
from transformers.testing_utils import require_datasets, require_faiss, require_torch
from transformers.tokenization_bart import BartTokenizer
from transformers.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
from transformers.tokenization_dpr import DPRQuestionEncoderTokenizer
from transformers.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES


if is_torch_available() and is_datasets_available() and is_faiss_available():
    from transformers.configuration_rag import RagConfig
    from transformers.tokenization_rag import RagTokenizer


@require_faiss
@require_datasets
@require_torch
class RagTokenizerTest(TestCase):
    def setUp(self):
        self.tmpdirname = tempfile.mkdtemp()
        self.retrieval_vector_size = 8

        # DPR tok
        vocab_tokens = [
            "[UNK]",
            "[CLS]",
            "[SEP]",
            "[PAD]",
            "[MASK]",
            "want",
            "##want",
            "##ed",
            "wa",
            "un",
            "runn",
            "##ing",
            ",",
            "low",
            "lowest",
        ]
        dpr_tokenizer_path = os.path.join(self.tmpdirname, "dpr_tokenizer")
        os.makedirs(dpr_tokenizer_path, exist_ok=True)
        self.vocab_file = os.path.join(dpr_tokenizer_path, DPR_VOCAB_FILES_NAMES["vocab_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

        # BART tok
        vocab = [
            "l",
            "o",
            "w",
            "e",
            "r",
            "s",
            "t",
            "i",
            "d",
            "n",
            "\u0120",
            "\u0120l",
            "\u0120n",
            "\u0120lo",
            "\u0120low",
            "er",
            "\u0120lowest",
            "\u0120newer",
            "\u0120wider",
            "<unk>",
        ]
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
        self.special_tokens_map = {"unk_token": "<unk>"}

        bart_tokenizer_path = os.path.join(self.tmpdirname, "bart_tokenizer")
        os.makedirs(bart_tokenizer_path, exist_ok=True)
        self.vocab_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["vocab_file"])
        self.merges_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["merges_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as fp:
            fp.write(json.dumps(vocab_tokens) + "\n")
        with open(self.merges_file, "w", encoding="utf-8") as fp:
            fp.write("\n".join(merges))

    def get_dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer:
        return DPRQuestionEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer"))

    def get_bart_tokenizer(self) -> BartTokenizer:
        return BartTokenizer.from_pretrained(os.path.join(self.tmpdirname, "bart_tokenizer"))

    def tearDown(self):
        shutil.rmtree(self.tmpdirname)

    def test_save_load_pretrained_with_saved_config(self):
        save_dir = os.path.join(self.tmpdirname, "rag_tokenizer")
        rag_config = RagConfig(question_encoder=DPRConfig().to_dict(), generator=BartConfig().to_dict())
        rag_tokenizer = RagTokenizer(question_encoder=self.get_dpr_tokenizer(), generator=self.get_bart_tokenizer())
        rag_config.save_pretrained(save_dir)
        rag_tokenizer.save_pretrained(save_dir)
        new_rag_tokenizer = RagTokenizer.from_pretrained(save_dir, config=rag_config)
        self.assertIsInstance(new_rag_tokenizer.question_encoder, DPRQuestionEncoderTokenizer)
        self.assertEqual(new_rag_tokenizer.question_encoder.vocab, rag_tokenizer.question_encoder.vocab)
        self.assertIsInstance(new_rag_tokenizer.generator, BartTokenizer)
        self.assertEqual(new_rag_tokenizer.generator.encoder, rag_tokenizer.generator.encoder)
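
For orientation, a minimal, hypothetical sketch of the save/load round trip this test asserts. It only uses the API shown in the test (the RagTokenizer constructor, save_pretrained, from_pretrained, and the question_encoder/generator attributes); the dpr_tokenizer and bart_tokenizer variables are assumed to be built the same way as in setUp() above, and the scratch directory is a placeholder.

# Minimal sketch (not part of the diff): RagTokenizer wraps a DPR question-encoder tokenizer
# and a BART generator tokenizer and round-trips both through a single directory.
import tempfile

from transformers.configuration_bart import BartConfig
from transformers.configuration_dpr import DPRConfig
from transformers.configuration_rag import RagConfig
from transformers.tokenization_rag import RagTokenizer

save_dir = tempfile.mkdtemp()  # placeholder scratch directory
rag_config = RagConfig(question_encoder=DPRConfig().to_dict(), generator=BartConfig().to_dict())

# dpr_tokenizer / bart_tokenizer are assumed to exist, e.g. the dummy tokenizers from setUp().
rag_tokenizer = RagTokenizer(question_encoder=dpr_tokenizer, generator=bart_tokenizer)
rag_config.save_pretrained(save_dir)
rag_tokenizer.save_pretrained(save_dir)

# Reloading restores both sub-tokenizers: questions are encoded with the DPR tokenizer,
# generator text is handled by the BART tokenizer.
reloaded = RagTokenizer.from_pretrained(save_dir, config=rag_config)
question_ids = reloaded.question_encoder("an example question")["input_ids"]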