* added rag WIP

* path fix

* Formatting / renaming prior to actual work

* added rag WIP

* path fix

* Formatting / renaming prior to actual work

* added rag WIP

* path fix

* Formatting / renaming prior to actual work

* added rag WIP

* Formatting / renaming prior to actual work

* First commit

* improve comments

* Retrieval evaluation scripts

* refactor to include modeling outputs + MPI retriever

* Fix rag-token model + refactor

* Various fixes + finetuning logic

* use_bos fix

* Retrieval refactor

* Finetuning refactoring and cleanup

* Add documentation and cleanup

* Remove set_up_rag_env.sh file

* Fix retrieval wit HF index

* Fix import errors

* Fix quality errors

* Refactor as per suggestions in https://github.com/huggingface/transformers/pull/6813#issuecomment-687208867

* fix quality

* Fix RAG Sequence generation

* minor cleanup plus initial tests

* fix test

* fix tests 2

* Comments fix

* post-merge fixes

* Improve readme + post-rebase refactor

* Extra dependencied for tests

* Fix tests

* Fix tests 2

* Refactor test requirements

* Fix tests 3

* Post-rebase refactor

* rename nlp->datasets

* RAG integration tests

* add tokenizer to slow integration test and allow retriever to run on cpu

* add tests; fix position ids warning

* change structure

* change structure

* add from encoder generator

* save working solution

* make all integration tests pass

* add RagTokenizer.save/from_pretrained and RagRetriever.save/from_pretrained

* don't save paths

* delete unnecessary imports

* pass config to AutoTokenizer.from_pretrained for Rag tokenizers

* init wiki_dpr only once

* hardcode legacy index and passages paths (todo: add the right urls)

* finalize config

* finalize retriver api and config api

* LegacyIndex index download refactor

* add dpr to autotokenizer

* make from pretrained more flexible

* fix ragfortokengeneration

* small name changes in tokenizer

* add labels to models

* change default index name

* add retrieval tests

* finish token generate

* align test with previous version and make all tests pass

* add tests

* finalize tests

* implement thoms suggestions

* add first version of test

* make first tests work

* make retriever platform agnostic

* naming

* style

* add legacy index URL

* docstrings + simple retrieval test for distributed

* clean model api

* add doc_ids to retriever's outputs

* fix retrieval tests

* finish model outputs

* finalize model api

* fix generate problem for rag

* fix generate for other modles

* fix some tests

* save intermediate

* set generate to default

* big refactor generate

* delete rag_api

* correct pip faiss install

* fix auto tokenization test

* fix faiss install

* fix test

* move the distributed logic to examples

* model page

* docs

* finish tests

* fix dependencies

* fix import in __init__

* Refactor eval_rag and finetune scripts

* start docstring

* add psutil to test

* fix tf test

* move require torch to top

* fix retrieval test

* align naming

* finish automodel

* fix repo consistency

* test ragtokenizer save/load

* add rag model output docs

* fix ragtokenizer save/load from pretrained

* fix tokenizer dir

* remove torch in retrieval

* fix docs

* fixe finetune scripts

* finish model docs

* finish docs

* remove auto model for now

* add require torch

* remove solved todos

* integrate sylvains suggestions

* sams comments

* correct mistake on purpose

* improve README

* Add generation test cases

* fix rag token

* clean token generate

* fix test

* add note to test

* fix attention mask

* add t5 test for rag

* Fix handling prefix in finetune.py

* don't overwrite index_name

Co-authored-by: Patrick Lewis <plewis@fb.com>
Co-authored-by: Aleksandra Piktus <piktus@devfair0141.h2.fair>
Co-authored-by: Aleksandra Piktus <piktus@learnfair5102.h2.fair>
Co-authored-by: Aleksandra Piktus <piktus@learnfair5067.h2.fair>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Quentin Lhoest <lhoest.q@gmail.com>
This commit is contained in:
Ola Piktus 2020-09-22 17:29:58 +01:00 коммит произвёл GitHub
Родитель 1ee2194fb6
Коммит c754c41c61
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
37 изменённых файлов: 5176 добавлений и 32 удалений

2
.gitignore поставляемый
Просмотреть файл

@ -11,6 +11,7 @@ __pycache__/
# tests and logs
tests/fixtures
logs/
lightning_logs/
# Distribution / packaging
.Python
@ -139,6 +140,7 @@ runs
/wandb
/examples/runs
/examples/**/*.args
/examples/rag/sweep
# data
/data

Просмотреть файл

@ -231,6 +231,7 @@ conversion utilities for the following models:
model_doc/lxmert
model_doc/bertgeneration
model_doc/layoutlm
model_doc/rag
internal/modeling_utils
internal/tokenization_utils
internal/pipelines_utils

Просмотреть файл

@ -0,0 +1,88 @@
RAG
----------------------------------------------------
Overview
~~~~~~~~~~~~~~~~~~~~~
Retrieval-augmented generation ("RAG") models combine the powers of pretrained dense retrieval (DPR) and Seq2Seq models.
RAG models retrieve docs, pass them to a seq2seq model, then marginalize to generate outputs.
The retriever and seq2seq modules are initialized from pretrained models, and fine-tuned jointly, allowing both retrieval and generation to adapt to downstream tasks.
It is based on the paper `Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks <https://arxiv.org/abs/2005.11401>`__ by Patrick Lewis, Ethan Perez, Aleksandara Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, Sebastian Riedel, Douwe Kiela.
The abstract from the paper is the following:
*Large pre-trained language models have been shown to store factual knowledge
in their parameters, and achieve state-of-the-art results when fine-tuned on
downstream NLP tasks. However, their ability to access and precisely manipulate
knowledge is still limited, and hence on knowledge-intensive tasks, their
performance lags behind task-specific architectures. Additionally, providing
provenance for their decisions and updating their world knowledge remain open
research problems. Pre-trained models with a differentiable access mechanism to
explicit nonparametric memory can overcome this issue, but have so far been only
investigated for extractive downstream tasks. We explore a general-purpose
fine-tuning recipe for retrieval-augmented generation (RAG) — models which combine
pre-trained parametric and non-parametric memory for language generation. We
introduce RAG models where the parametric memory is a pre-trained seq2seq model and
the non-parametric memory is a dense vector index of Wikipedia, accessed with
a pre-trained neural retriever. We compare two RAG formulations, one which
conditions on the same retrieved passages across the whole generated sequence, the
other can use different passages per token. We fine-tune and evaluate our models
on a wide range of knowledge-intensive NLP tasks and set the state-of-the-art
on three open domain QA tasks, outperforming parametric seq2seq models and
task-specific retrieve-and-extract architectures. For language generation tasks, we
find that RAG models generate more specific, diverse and factual language than a
state-of-the-art parametric-only seq2seq baseline.*
RagConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.RagConfig
:members:
RagTokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.RagTokenizer
:members:
Rag specific outputs
~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.modeling_rag.RetrievAugLMMarginOutput
:members:
.. autoclass:: transformers.modeling_rag.RetrievAugLMOutput
:members:
RAGRetriever
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.RagRetriever
:members:
RagModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.RagModel
:members: forward
RagSequenceForGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.RagSequenceForGeneration
:members: forward, generate
RagTokenForGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.RagTokenForGeneration
:members: forward, generate

Просмотреть файл

@ -654,7 +654,7 @@ DPR
<a href="https://huggingface.co/models?filter=dpr">
<img alt="Models" src="https://img.shields.io/badge/All_model_pages-dpr-blueviolet">
</a>
<a href="model_doc/ctrl.dpr">
<a href="model_doc/dpr.html">
<img alt="Doc" src="https://img.shields.io/badge/Model_documentation-dpr-blueviolet">
</a>
@ -672,6 +672,27 @@ DPR consists in three models:
DPR's pipeline (not implemented yet) uses a retrieval step to find the top k contexts given a certain question, and then it calls the reader with the question and the retrieved documents to get the answer.
RAG
----------------------------------------------
.. raw:: html
<a href="https://huggingface.co/models?filter=rag">
<img alt="Models" src="https://img.shields.io/badge/All_model_pages-rag-blueviolet">
</a>
<a href="model_doc/rag.html">
<img alt="Doc" src="https://img.shields.io/badge/Model_documentation-rag-blueviolet">
</a>
`Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks <https://arxiv.org/abs/2005.11401>`_,
Patrick Lewis, Ethan Perez, Aleksandara Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, Sebastian Riedel, Douwe Kiela
Retrieval-augmented generation ("RAG") models combine the powers of pretrained dense retrieval (DPR) and Seq2Seq models.
RAG models retrieve docs, pass them to a seq2seq model, then marginalize to generate outputs.
The retriever and seq2seq modules are initialized from pretrained models, and fine-tuned jointly, allowing both retrieval and generation to adapt to downstream tasks.
The two models RAG-Token and RAG-Sequence are available for generation.
More technical aspects
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Просмотреть файл

@ -366,6 +366,8 @@ def generic_train(
if args.gpus > 1:
train_params["distributed_backend"] = "ddp"
train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
trainer = pl.Trainer.from_argparse_args(
args,
weights_summary=None,

Просмотреть файл

@ -1,10 +1,10 @@
import datasets
import faiss
import numpy as np
import streamlit as st
import torch
from elasticsearch import Elasticsearch
import faiss
import transformers
from eli5_utils import (
embed_questions_for_retrieval,

Просмотреть файл

@ -5,7 +5,6 @@ from random import choice, randint
from time import time
import datasets # noqa: F401
import faiss # noqa: F401
import numpy as np
import pandas as pd
import torch
@ -15,6 +14,7 @@ from elasticsearch.helpers import bulk, streaming_bulk # noqa: F401
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from tqdm import tqdm
import faiss # noqa: F401
from transformers import AdamW, AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer, get_linear_schedule_with_warmup

88
examples/rag/README.md Normal file
Просмотреть файл

@ -0,0 +1,88 @@
# Intro
RAG is a seq2seq model which encapsulates two core components: a question encoder and a generator.
During a forward pass, we encode the input with the question encoder and pass it
to the retriever to extract relevant context documents. The documents are then prepended to the input.
Such contextualized inputs is passed to the generator.
The question encoder can be any `autoencoding` model, preferably :obj:`~transformers.DPRQuestionEncoder`, and the generator can be any `seq2seq` model, preferably :obj:`~transformers.BartForConditionalGeneration`.
The model can be initialized with a :obj:`~transformers.RagRetriever` for end-to-end generation or used in combination with the outputs of a retriever in multiple steps - see examples for more details.
The model is compatible any `autoencoding` model as the ``question_encoder`` and any `seq2seq` model with language model head as the ``generator``.
The model has been tested with :class:`~transformers.DPRQuestionEncoder` as the ``question_encoder`` and :class:`~transformers.BartForConditionalGeneration` or :class:`~transformers.T5ForConditionalGeneration` as the ``generator``.
RAG models were released with the paper `Retrieval-Augmented Generation for
Knowledge-Intensive NLP Tasks <https://arxiv.org/abs/2005.11401>`_ by Patrick Lewis, Ethan Perez, Aleksandra Piktus et al.
# Finetuning
Our finetuning logic is based on scripts from [`examples/seq2seq`](https://github.com/huggingface/transformers/tree/master/examples/seq2seq).
Follow instructions there regarding data preprocessing. A sample finetuning command:
```
python examples/rag/finetune.py \
--data_dir $DATA_DIR \
--output_dir $OUTPUT_DIR \
--model_name_or_path $MODEL_NAME_OR_PATH \
--model_type rag_sequence \
--fp16 \
--gpus 8
```
# Evaluation
Apart from the parameters specifying the model to evaluate and some extra parameters, the evaluation script expects paths to two files:
- `evaluation_set` - a path to a file specifying the evaluation dataset, a single datapoint per line, e.g.
```who is the owner of reading football club```
- `gold_data_path` - a path to a file contaning ground truth answers for datapoints from the `evaluation_set`.
We expect the following formats of the gold data file:
- for e2e evaluation, we support two formats of the gold file:
- `qa` - where a single line in the following format: input [tab] output_list, e.g.:
```
who is the owner of reading football club ['Xiu Li Dai', 'Dai Yongge', 'Dai Xiuli', 'Yongge Dai']
```
- `ans` - where a single line of the gold file contains the expected output string, e.g.:
```
Xiu Li Dai
```
- for retrieval evaluation, we expect a tab-separated list of Wikipedia page titles constituting positive contexts for a given query, e.g. given a question `who sings does he love me with reba`, a line with ground truth retrieval data could look as follows:
```
Does He Love You Does He Love You Red Sandy Spika dress of Reba McEntire Greatest Hits Volume Two (Reba McEntire album) Shoot for the Moon (album)
```
## Retrieval evaluation
We demonstrate how to evaluate retrieval against DPR evaluation data. You can download respective files from links listed [here](https://github.com/facebookresearch/DPR/blob/master/data/download_data.py#L39-L45).
1. Download and unzip the gold data file. We use the `biencoder-nq-dev` from https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-dev.json.gz.
2. Parse the unziped file using the `parse_dpr_relevance_data.py`
```
python examples/rag/parse_dpr_relevance_data.py --src_path path/to/unziped/biencoder-nq-dev.json --evaluation_set path/to/output/biencoder-nq-dev.questions --gold_data_path path/to/output/biencoder-nq-dev.pages
```
3. Run evaluation:
```
python examples/rag/eval_rag.py \
--model_name_or_path $MODEL_NAME_OR_PATH \ # model name or path of the model we're evaluating
--model_type rag_sequence \ # RAG model type (rag_token or rag_sequence)
--evaluation_set path/to/output/biencoder-nq-dev.questions \ # an input dataset for evaluation
--gold_data_path path/to/output/biencoder-nq-dev.pages \ # a dataset containing ground truth answers for samples from the evaluation_set
--predictions_path path/to/retrieval_preds.tsv \ # name of file in which predictions will be stored
--eval_mode retrieval \ # indicates whether we're performing retrieval evaluation or e2e evaluation
--recalculate # if predictions_filename already exists, and this option is set - we regenerate the answers, otherwise we reuse the predicsion file to calculate metrics.
```
## End-to-end evaluation
```
python examples/rag/eval_rag.py \
--model_name_or_path $MODEL_NAME_OR_PATH \
--model_type rag_sequence \
--evaluation_set path/to/test.source \
--gold_data_path path/to/gold_data \
--predictions_path path/to/e2e_preds.txt \
--eval_mode e2e \ # indicates whether we're performing retrieval evaluation or e2e evaluation (default)
--n_docs 5 \ # You can experiment with retrieving different number of documents at evaluation time
--print_predictions
```

0
examples/rag/__init__.py Normal file
Просмотреть файл

30
examples/rag/callbacks.py Normal file
Просмотреть файл

@ -0,0 +1,30 @@
import logging
import os
from pytorch_lightning.callbacks import ModelCheckpoint
logger = logging.getLogger(__name__)
def get_checkpoint_callback(output_dir, metric):
"""Saves the best model by validation EM score."""
if metric == "rouge2":
exp = "{val_avg_rouge2:.4f}-{step_count}"
elif metric == "bleu":
exp = "{val_avg_bleu:.4f}-{step_count}"
elif metric == "em":
exp = "{val_avg_em:.4f}-{step_count}"
else:
raise NotImplementedError(
f"seq2seq callbacks only support rouge2 and bleu, got {metric}, You can make your own by adding to this function."
)
checkpoint_callback = ModelCheckpoint(
filepath=os.path.join(output_dir, exp),
monitor=f"val_{metric}",
mode="max",
save_top_k=3,
period=0, # maybe save a checkpoint every time val is run, not just end of epoch.
)
return checkpoint_callback

Просмотреть файл

@ -0,0 +1,135 @@
import logging
import os
from typing import List, Tuple
import numpy as np
import psutil
import torch
import torch.distributed as dist
from transformers import RagRetriever
logger = logging.getLogger(__name__)
class RagPyTorchDistributedRetriever(RagRetriever):
"""
A distributed retriever built on top of the ``torch.distributed`` communication package. During training all workers
initalize their own instance of the retriever, however, only the main worker loads the index into memory. The index is stored
in cpu memory. The index will also work well in a non-distributed setup.
Args:
config (:class:`~transformers.RagConfig`):
The configuration of the RAG model this Retriever is used with. Contains parameters indicating which ``Index`` to build.
question_encoder_tokenizer (:class:`~transformers.PretrainedTokenizer`):
The tokenizer that was used to tokenize the question.
It is used to decode the question and then use the generator_tokenizer.
generator_tokenizer (:class:`~transformers.PretrainedTokenizer`):
The tokenizer used for the generator part of the RagModel.
"""
_init_retrieval = False
def __init__(self, config, question_encoder_tokenizer, generator_tokenizer):
super().__init__(
config, question_encoder_tokenizer=question_encoder_tokenizer, generator_tokenizer=generator_tokenizer
)
self.process_group = None
def init_retrieval(self, distributed_port: int):
"""
Retriever initalization function, needs to be called from the training process. The function sets some common parameters
and environment variables. On top of that, (only) the main process in the process group loads the index into memory.
Args:
distributed_port (:obj:`int`):
The port on which the main communication of the training run is carried out. We set the port for retrieval-related
communication as ``distributed_port + 1``.
"""
logger.info("initializing retrieval")
# initializing a separate process group for retrievel as the default
# nccl backend doesn't support gather/scatter operations while gloo
# is too slow to replace nccl for the core gpu communication
if dist.is_initialized():
logger.info("dist initialized")
# needs to be set manually
os.environ["GLOO_SOCKET_IFNAME"] = self._infer_socket_ifname()
# avoid clash with the NCCL port
os.environ["MASTER_PORT"] = str(distributed_port + 1)
self.process_group = dist.new_group(ranks=None, backend="gloo")
# initialize retriever only on the main worker
if not dist.is_initialized() or self._is_main():
logger.info("dist not initialized / main")
self.index.init_index()
# all processes wait untill the retriever is initialized by the main process
if dist.is_initialized():
torch.distributed.barrier(group=self.process_group)
def _is_main(self):
return dist.get_rank(group=self.process_group) == 0
def _scattered(self, scatter_list, target_shape, target_type=torch.float32):
target_tensor = torch.empty(target_shape, dtype=target_type)
dist.scatter(target_tensor, src=0, scatter_list=scatter_list, group=self.process_group)
return target_tensor
def _infer_socket_ifname(self):
addrs = psutil.net_if_addrs()
# a hacky way to deal with varying network interface names
ifname = next((addr for addr in addrs if addr.startswith("e")), None)
return ifname
def retrieve(self, question_hidden_states: np.ndarray, n_docs: int) -> Tuple[np.ndarray, List[dict]]:
"""
Retrieves documents for specified ``question_hidden_states``. The main process, which has the access to the index stored in memory, gathers queries
from all the processes in the main training process group, performs the retrieval and scatters back the results.
Args:
question_hidden_states (:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size)`):
A batch of query vectors to retrieve with.
n_docs (:obj:`int`):
The number of docs retrieved per query.
Ouput:
retrieved_doc_embeds (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs, dim)`
The retrieval embeddings of the retrieved docs per query.
doc_ids (:obj:`np.ndarray` of shape :obj:`batch_size, n_docs`)
The ids of the documents in the index
doc_dicts (:obj:`List[dict]`):
The retrieved_doc_embeds examples per query.
"""
# single GPU training
if not dist.is_initialized():
doc_ids, retrieved_doc_embeds = self._main_retrieve(question_hidden_states, n_docs)
return retrieved_doc_embeds, doc_ids, self.index.get_doc_dicts(doc_ids)
# distributed training
world_size = dist.get_world_size(group=self.process_group)
# gather logic
gather_list = None
if self._is_main():
gather_list = [torch.empty(question_hidden_states.shape, dtype=torch.float32) for _ in range(world_size)]
dist.gather(torch.tensor(question_hidden_states), dst=0, gather_list=gather_list, group=self.process_group)
# scatter logic
n_queries = question_hidden_states.shape[0]
scatter_ids = []
scatter_vectors = []
if self._is_main():
assert len(gather_list) == world_size
ids, vectors = self._main_retrieve(torch.cat(gather_list).numpy(), n_docs)
ids, vectors = torch.tensor(ids), torch.tensor(vectors)
scatter_ids = self._chunk_tensor(ids, n_queries)
scatter_vectors = self._chunk_tensor(vectors, n_queries)
doc_ids = self._scattered(scatter_ids, [n_queries, n_docs], target_type=torch.int64)
retrieved_doc_embeds = self._scattered(scatter_vectors, [n_queries, n_docs, question_hidden_states.shape[1]])
return retrieved_doc_embeds.numpy(), doc_ids.numpy(), self.index.get_doc_dicts(doc_ids)

310
examples/rag/eval_rag.py Normal file
Просмотреть файл

@ -0,0 +1,310 @@
""" Evaluation script for RAG models."""
import argparse
import ast
import logging
import os
import sys
import pandas as pd
import torch
from tqdm import tqdm
from transformers import BartForConditionalGeneration, RagRetriever, RagSequenceForGeneration, RagTokenForGeneration
from transformers import logging as transformers_logging
sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # isort:skip
from examples.rag.utils import exact_match_score, f1_score # noqa: E402 # isort:skip
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
transformers_logging.set_verbosity_info()
def infer_model_type(model_name_or_path):
if "token" in model_name_or_path:
return "rag_token"
if "sequence" in model_name_or_path:
return "rag_sequence"
if "bart" in model_name_or_path:
return "bart"
return None
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
return max(metric_fn(prediction, gt) for gt in ground_truths)
def get_scores(args, preds_path, gold_data_path):
hypos = [line.strip() for line in open(preds_path, "r").readlines()]
answers = []
if args.gold_data_mode == "qa":
data = pd.read_csv(gold_data_path, sep="\t", header=None)
for answer_list in data[1]:
ground_truths = ast.literal_eval(answer_list)
answers.append(ground_truths)
else:
references = [line.strip() for line in open(gold_data_path, "r").readlines()]
answers = [[reference] for reference in references]
f1 = em = total = 0
for prediction, ground_truths in zip(hypos, answers):
total += 1
em += metric_max_over_ground_truths(exact_match_score, prediction, ground_truths)
f1 += metric_max_over_ground_truths(f1_score, prediction, ground_truths)
em = 100.0 * em / total
f1 = 100.0 * f1 / total
logger.info(f"F1: {f1:.2f}")
logger.info(f"EM: {em:.2f}")
def get_precision_at_k(args, preds_path, gold_data_path):
k = args.k
hypos = [line.strip() for line in open(preds_path, "r").readlines()]
references = [line.strip() for line in open(gold_data_path, "r").readlines()]
em = total = 0
for hypo, reference in zip(hypos, references):
hypo_provenance = set(hypo.split("\t")[:k])
ref_provenance = set(reference.split("\t")[1 : (k + 1)])
total += 1
em += len(hypo_provenance & ref_provenance) / k
em = 100.0 * em / total
logger.info(f"Precision@{k}: {em: .2f}")
def evaluate_batch_retrieval(args, rag_model, questions):
def strip_title(title):
if title.startswith('"'):
title = title[1:]
if title.endswith('"'):
title = title[:-1]
return title
retriever_input_ids = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus(
questions,
return_tensors="pt",
padding=True,
truncation=True,
)["input_ids"].to(args.device)
question_enc_outputs = rag_model.rag.question_encoder(retriever_input_ids, return_dict=True)
question_enc_pool_output = question_enc_outputs.pooler_output
result = rag_model.retriever(
retriever_input_ids,
question_enc_pool_output.cpu().detach().to(torch.float32).numpy(),
prefix=rag_model.rag.generator.config.prefix,
n_docs=rag_model.config.n_docs,
return_tensors="pt",
)
all_docs = rag_model.retriever.index.get_doc_dicts(result.doc_ids)
provenance_strings = []
for docs in all_docs:
provenance = [strip_title(title) for title in docs["title"]]
provenance_strings.append("\t".join(provenance))
return provenance_strings
def evaluate_batch_e2e(args, rag_model, questions):
with torch.no_grad():
input_ids = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus(
questions, return_tensors="pt", padding=True, truncation=True
)["input_ids"].to(args.device)
outputs = rag_model.generate( # rag_model overwrites generate
input_ids,
num_beams=args.num_beams,
min_length=args.min_length,
max_length=args.max_length,
early_stopping=False,
num_return_sequences=1,
bad_words_ids=[[0, 0]], # BART likes to repeat BOS tokens, dont allow it to generate more than one
clean_up_tokenization=True,
print_docs=args.print_docs,
)
answers = rag_model.retriever.generator_tokenizer.batch_decode(outputs, skip_special_tokens=True)
if args.print_predictions:
for q, a in zip(questions, answers):
logger.info("Q: {} - A: {}".format(q, a))
return answers
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_type",
choices=["rag_sequence", "rag_token", "bart"],
type=str,
help="RAG model type: rag_sequence, rag_token or bart, if none specified, the type is inferred from the model_name_or_path",
)
parser.add_argument(
"--index_name",
default=None,
choices=["hf", "legacy"],
type=str,
help="RAG model retriever type",
)
parser.add_argument(
"--index_path",
default=None,
type=str,
help="Path to the retrieval index",
)
parser.add_argument("--n_docs", default=5, type=int, help="Number of retrieved docs")
parser.add_argument(
"--model_name_or_path",
default=None,
type=str,
required=True,
help="Path to pretrained checkpoints or model identifier from huggingface.co/models",
)
parser.add_argument(
"--eval_mode",
choices=["e2e", "retrieval"],
default="e2e",
type=str,
help="Evaluation mode, e2e calculates exact match and F1 of the downstream task, retrieval calulates precision@k.",
)
parser.add_argument("--k", default=1, type=int, help="k for the precision@k calculation")
parser.add_argument(
"--evaluation_set",
default=None,
type=str,
required=True,
help="Path to a file containing evaluation samples",
)
parser.add_argument(
"--gold_data_path",
default=None,
type=str,
required=True,
help="Path to a tab-separated file with gold samples",
)
parser.add_argument(
"--gold_data_mode",
default="qa",
type=str,
choices=["qa", "ans"],
help="Format of the gold data file"
"qa - a single line in the following format: question [tab] answer_list"
"ans - a single line of the gold file contains the expected answer string",
)
parser.add_argument(
"--predictions_path",
type=str,
default="predictions.txt",
help="Name of the predictions file, to be stored in the checkpoints directry",
)
parser.add_argument(
"--eval_all_checkpoints",
action="store_true",
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
)
parser.add_argument(
"--eval_batch_size",
default=8,
type=int,
help="Batch size per GPU/CPU for evaluation.",
)
parser.add_argument(
"--recalculate",
help="Recalculate predictions even if the prediction file exists",
action="store_true",
)
parser.add_argument(
"--num_beams",
default=4,
type=int,
help="Number of beams to be used when generating answers",
)
parser.add_argument("--min_length", default=1, type=int, help="Min length of the generated answers")
parser.add_argument("--max_length", default=50, type=int, help="Max length of the generated answers")
parser.add_argument(
"--print_predictions",
action="store_true",
help="If True, prints predictions while evaluating.",
)
parser.add_argument(
"--print_docs",
action="store_true",
help="If True, prints docs retried while generating.",
)
args = parser.parse_args()
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
return args
def main(args):
model_kwargs = {}
if args.model_type is None:
args.model_type = infer_model_type(args.model_name_or_path)
assert args.model_type is not None
if args.model_type.startswith("rag"):
model_class = RagTokenForGeneration if args.model_type == "rag_token" else RagSequenceForGeneration
model_kwargs["n_docs"] = args.n_docs
if args.index_name is not None:
model_kwargs["index_name"] = args.index_name
if args.index_path is not None:
model_kwargs["index_path"] = args.index_path
else:
model_class = BartForConditionalGeneration
checkpoints = (
[f.path for f in os.scandir(args.model_name_or_path) if f.is_dir()]
if args.eval_all_checkpoints
else [args.model_name_or_path]
)
logger.info("Evaluate the following checkpoints: %s", checkpoints)
score_fn = get_scores if args.eval_mode == "e2e" else get_precision_at_k
evaluate_batch_fn = evaluate_batch_e2e if args.eval_mode == "e2e" else evaluate_batch_retrieval
for checkpoint in checkpoints:
if os.path.exists(args.predictions_path) and (not args.recalculate):
logger.info("Calculating metrics based on an existing predictions file: {}".format(args.predictions_path))
score_fn(args, args.predictions_path, args.gold_data_path)
continue
logger.info("***** Running evaluation for {} *****".format(checkpoint))
logger.info(" Batch size = %d", args.eval_batch_size)
logger.info(" Predictions will be stored under {}".format(args.predictions_path))
if args.model_type.startswith("rag"):
retriever = RagRetriever.from_pretrained(checkpoint, **model_kwargs)
model = model_class.from_pretrained(checkpoint, retriever=retriever, **model_kwargs)
model.retriever.init_retrieval()
else:
model = model_class.from_pretrained(checkpoint, **model_kwargs)
model.to(args.device)
with open(args.evaluation_set, "r") as eval_file, open(args.predictions_path, "w") as preds_file:
questions = []
for line in tqdm(eval_file):
questions.append(line.strip())
if len(questions) == args.eval_batch_size:
answers = evaluate_batch_fn(args, model, questions)
preds_file.write("\n".join(answers) + "\n")
preds_file.flush()
questions = []
if len(questions) > 0:
answers = evaluate_batch_fn(args, model, questions)
preds_file.write("\n".join(answers))
preds_file.flush()
score_fn(args, args.predictions_path, args.gold_data_path)
if __name__ == "__main__":
args = get_args()
main(args)

474
examples/rag/finetune.py Normal file
Просмотреть файл

@ -0,0 +1,474 @@
"""Finetuning script for RAG models. Adapted from examples.seq2seq.finetune.py"""
import argparse
import glob
import logging
import os
import sys
import time
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Tuple
import numpy as np
import pytorch_lightning as pl
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from transformers import (
AutoConfig,
AutoTokenizer,
BartForConditionalGeneration,
RagConfig,
RagSequenceForGeneration,
RagTokenForGeneration,
RagTokenizer,
T5ForConditionalGeneration,
get_linear_schedule_with_warmup,
)
from transformers import logging as transformers_logging
sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # noqa: E402 # isort:skip
from examples.lightning_base import BaseTransformer, add_generic_args, generic_train # noqa: E402 # isort:skip
from examples.rag.callbacks import get_checkpoint_callback # noqa: E402 # isort:skip
from examples.rag.distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
from examples.rag.utils import ( # noqa: E402 # isort:skip
Seq2SeqDataset,
calculate_exact_match,
is_rag_model,
set_extra_model_params,
)
from examples.seq2seq.callbacks import Seq2SeqLoggingCallback, get_early_stopping_callback # noqa: E402 # isort:skip
from examples.seq2seq.utils import ( # noqa: E402 # isort:skip
flatten_list,
get_git_info,
lmap,
pickle_save,
save_git_info,
save_json,
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
transformers_logging.set_verbosity_info()
class AttrDict(dict):
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__ = self
class GenerativeQAModule(BaseTransformer):
mode = "generative_qa"
loss_names = ["loss"]
metric_names = ["em"]
val_metric = "em"
def __init__(self, hparams, **kwargs):
# when loading from a pytorch lightning checkpoint, hparams are passed as dict
if isinstance(hparams, dict):
hparams = AttrDict(hparams)
if hparams.model_type == "rag_sequence":
self.model_class = RagSequenceForGeneration
elif hparams.model_type == "rag_token":
self.model_class = RagTokenForGeneration
elif hparams.model_type == "bart":
self.model_class = BartForConditionalGeneration
else:
self.model_class = T5ForConditionalGeneration
self.is_rag_model = is_rag_model(hparams.model_type)
config_class = RagConfig if self.is_rag_model else AutoConfig
config = config_class.from_pretrained(hparams.model_name_or_path)
# set extra_model_params for generator configs and load_model
extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "attention_dropout", "dropout")
if self.is_rag_model:
if args.prefix is not None:
config.generator.prefix = args.prefix
config.label_smoothing = hparams.label_smoothing
hparams, config.generator = set_extra_model_params(extra_model_params, hparams, config.generator)
retriever = RagPyTorchDistributedRetriever.from_pretrained(hparams.model_name_or_path)
model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config, retriever=retriever)
prefix = config.question_encoder.prefix
else:
if args.prefix is not None:
config.prefix = args.prefix
hparams, config = set_extra_model_params(extra_model_params, hparams, config)
model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config)
prefix = config.prefix
tokenizer = (
RagTokenizer.from_pretrained(hparams.model_name_or_path)
if self.is_rag_model
else AutoTokenizer.from_pretrained(hparams.model_name_or_path)
)
super().__init__(hparams, config=config, tokenizer=tokenizer, model=model)
save_git_info(self.hparams.output_dir)
self.output_dir = Path(self.hparams.output_dir)
self.metrics_save_path = Path(self.output_dir) / "metrics.json"
self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
pickle_save(self.hparams, self.hparams_save_path)
self.step_count = 0
self.metrics = defaultdict(list)
self.dataset_kwargs: dict = dict(
data_dir=self.hparams.data_dir,
max_source_length=self.hparams.max_source_length,
prefix=prefix or "",
)
n_observations_per_split = {
"train": self.hparams.n_train,
"val": self.hparams.n_val,
"test": self.hparams.n_test,
}
self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}
self.target_lens = {
"train": self.hparams.max_target_length,
"val": self.hparams.val_max_target_length,
"test": self.hparams.test_max_target_length,
}
assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
self.hparams.git_sha = get_git_info()["repo_sha"]
self.num_workers = hparams.num_workers
self.distributed_port = self.hparams.distributed_port
def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True):
logger.info("Custom init_ddp_connection.")
os.environ["MASTER_PORT"] = str(self.distributed_port)
super().init_ddp_connection(global_rank, world_size, is_slurm_managing_tasks)
if self.is_rag_model:
self.model.retriever.init_retrieval(self.distributed_port)
def forward(self, input_ids, **kwargs):
return self.model(input_ids, **kwargs)
def ids_to_clean_text(self, generated_ids: List[int]):
gen_text = self.tokenizer.batch_decode(
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
return lmap(str.strip, gen_text)
def _step(self, batch: dict) -> Tuple:
source_ids, source_mask, target_ids = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
rag_kwargs = {}
if isinstance(self.model, T5ForConditionalGeneration):
decoder_input_ids = self.model._shift_right(target_ids)
lm_labels = target_ids
elif isinstance(self.model, BartForConditionalGeneration):
decoder_input_ids = target_ids[:, :-1].contiguous()
lm_labels = target_ids[:, 1:].clone()
else:
assert self.is_rag_model
generator = self.model.rag.generator
if isinstance(generator, T5ForConditionalGeneration):
decoder_start_token_id = generator.config.decoder_start_token_id
decoder_input_ids = (
torch.cat(
[torch.Tensor([[decoder_start_token_id]] * target_ids.shape[0]).to(target_ids), target_ids],
dim=1,
)
if target_ids.shape[0] < self.target_lens["train"]
else generator._shift_right(target_ids)
)
elif isinstance(generator, BartForConditionalGeneration):
decoder_input_ids = target_ids
lm_labels = decoder_input_ids
rag_kwargs["reduce_loss"] = True
assert decoder_input_ids is not None
outputs = self(
source_ids,
attention_mask=source_mask,
decoder_input_ids=decoder_input_ids,
use_cache=False,
labels=lm_labels,
return_dict=True,
**rag_kwargs,
)
loss = outputs["loss"]
return (loss,)
@property
def pad(self) -> int:
raise NotImplementedError("pad not implemented")
def training_step(self, batch, batch_idx) -> Dict:
loss_tensors = self._step(batch)
logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
# tokens per batch
tgt_pad_token_id = (
self.tokenizer.generator.pad_token_id
if isinstance(self.tokenizer, RagTokenizer)
else self.tokenizer.pad_token_id
)
src_pad_token_id = (
self.tokenizer.question_encoder.pad_token_id
if isinstance(self.tokenizer, RagTokenizer)
else self.tokenizer.pad_token_id
)
logs["tpb"] = (
batch["input_ids"].ne(src_pad_token_id).sum() + batch["decoder_input_ids"].ne(tgt_pad_token_id).sum()
)
return {"loss": loss_tensors[0], "log": logs}
def validation_step(self, batch, batch_idx) -> Dict:
return self._generative_step(batch)
def validation_epoch_end(self, outputs, prefix="val") -> Dict:
self.step_count += 1
losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
loss = losses["loss"]
gen_metrics = {
k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]
}
metrics_tensor: torch.FloatTensor = torch.tensor(gen_metrics[self.val_metric]).type_as(loss)
gen_metrics.update({k: v.item() for k, v in losses.items()})
# fix for https://github.com/PyTorchLightning/pytorch-lightning/issues/2424
if dist.is_initialized():
dist.all_reduce(metrics_tensor, op=dist.ReduceOp.SUM)
metrics_tensor = metrics_tensor / dist.get_world_size()
gen_metrics.update({self.val_metric: metrics_tensor.item()})
losses.update(gen_metrics)
metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
metrics["step_count"] = self.step_count
self.save_metrics(metrics, prefix) # writes to self.metrics_save_path
preds = flatten_list([x["preds"] for x in outputs])
return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_{self.val_metric}": metrics_tensor}
def save_metrics(self, latest_metrics, type_path) -> None:
self.metrics[type_path].append(latest_metrics)
save_json(self.metrics, self.metrics_save_path)
def calc_generative_metrics(self, preds, target) -> Dict:
return calculate_exact_match(preds, target)
def _generative_step(self, batch: dict) -> dict:
start_time = time.time()
generated_ids = self.model.generate(
batch["input_ids"],
do_deduplication=False, # rag specific parameter
use_cache=True,
min_length=1,
max_length=self.target_lens["val"],
)
gen_time = (time.time() - start_time) / batch["input_ids"].shape[0]
preds: List[str] = self.ids_to_clean_text(generated_ids)
target: List[str] = self.ids_to_clean_text(batch["decoder_input_ids"])
loss_tensors = self._step(batch)
base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
gen_metrics: Dict = self.calc_generative_metrics(preds, target)
summ_len = np.mean(lmap(len, generated_ids))
base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **gen_metrics)
return base_metrics
def test_step(self, batch, batch_idx):
return self._generative_step(batch)
def test_epoch_end(self, outputs):
return self.validation_epoch_end(outputs, prefix="test")
def get_dataset(self, type_path) -> Seq2SeqDataset:
n_obs = self.n_obs[type_path]
max_target_length = self.target_lens[type_path]
dataset = Seq2SeqDataset(
self.tokenizer,
type_path=type_path,
n_obs=n_obs,
max_target_length=max_target_length,
**self.dataset_kwargs,
)
return dataset
def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
dataset = self.get_dataset(type_path)
sampler = None
if self.hparams.sortish_sampler and type_path == "train":
assert self.hparams.gpus <= 1 # TODO: assert earlier
sampler = dataset.make_sortish_sampler(batch_size)
shuffle = False
dataloader = DataLoader(
dataset,
batch_size=batch_size,
collate_fn=dataset.collate_fn,
shuffle=shuffle,
num_workers=self.num_workers,
sampler=sampler,
)
return dataloader
def train_dataloader(self) -> DataLoader:
dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
t_total = (
(len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))
// self.hparams.accumulate_grad_batches
* float(self.hparams.max_epochs)
)
scheduler = get_linear_schedule_with_warmup(
self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
)
if max(scheduler.get_last_lr()) > 0:
warnings.warn("All learning rates are 0")
self.lr_scheduler = scheduler
return dataloader
def val_dataloader(self) -> DataLoader:
return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)
def test_dataloader(self) -> DataLoader:
return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size)
@pl.utilities.rank_zero_only
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
save_path = self.output_dir.joinpath("checkpoint{}".format(self.step_count))
self.model.config.save_step = self.step_count
self.model.save_pretrained(save_path)
self.tokenizer.save_pretrained(save_path)
@staticmethod
def add_model_specific_args(parser, root_dir):
BaseTransformer.add_model_specific_args(parser, root_dir)
add_generic_args(parser, root_dir)
parser.add_argument(
"--max_source_length",
default=128,
type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.",
)
parser.add_argument(
"--max_target_length",
default=25,
type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.",
)
parser.add_argument(
"--val_max_target_length",
default=25,
type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.",
)
parser.add_argument(
"--test_max_target_length",
default=25,
type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.",
)
parser.add_argument("--sortish_sampler", action="store_true", default=False)
parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
parser.add_argument("--n_val", type=int, default=-1, required=False, help="# examples. -1 means use all.")
parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
parser.add_argument(
"--prefix",
type=str,
default=None,
help="Prefix added at the beginning of each text, typically used with T5-based models.",
)
parser.add_argument(
"--early_stopping_patience",
type=int,
default=-1,
required=False,
help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will effect it.",
)
parser.add_argument(
"--distributed-port", type=int, default=-1, required=False, help="Port number for distributed training."
)
parser.add_argument(
"--model_type",
choices=["rag_sequence", "rag_token", "bart", "t5"],
type=str,
help="RAG model type: sequence or token, if none specified, the type is inferred from the model_name_or_path",
)
return parser
def main(args, model=None) -> GenerativeQAModule:
Path(args.output_dir).mkdir(exist_ok=True)
if model is None:
model: GenerativeQAModule = GenerativeQAModule(args)
dataset = Path(args.data_dir).name
if (
args.logger_name == "default"
or args.fast_dev_run
or str(args.output_dir).startswith("/tmp")
or str(args.output_dir).startswith("/var")
):
logger = True # don't pollute wandb logs unnecessarily
elif args.logger_name == "wandb":
from pytorch_lightning.loggers import WandbLogger
project = os.environ.get("WANDB_PROJECT", dataset)
logger = WandbLogger(name=model.output_dir.name, project=project)
elif args.logger_name == "wandb_shared":
from pytorch_lightning.loggers import WandbLogger
logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
es_callback = (
get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
if args.early_stopping_patience >= 0
else False
)
trainer: pl.Trainer = generic_train(
model,
args,
logging_callback=Seq2SeqLoggingCallback(),
checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
early_stopping_callback=es_callback,
logger=logger,
)
pickle_save(model.hparams, model.output_dir / "hparams.pkl")
if not args.do_predict:
return model
model.hparams.test_checkpoint = ""
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
if checkpoints:
model.hparams.test_checkpoint = checkpoints[-1]
trainer.resume_from_checkpoint = checkpoints[-1] # best checkpoint
trainer.logger.log_hyperparams(model.hparams)
# test() without a model tests using the best checkpoint automatically
trainer.test()
return model
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
args = parser.parse_args()
main(args)

34
examples/rag/finetune.sh Executable file
Просмотреть файл

@ -0,0 +1,34 @@
# Add parent directory to python path to access lightning_base.py
export PYTHONPATH="../":"${PYTHONPATH}"
# A sample finetuning run, you need to specify data_dir, output_dir and model_name_or_path
# run ./examples/rag/finetune.sh --help to see all the possible options
python examples/rag/finetune.py \
--data_dir $DATA_DIR \
--output_dir $OUTPUT_DIR \
--model_name_or_path $MODLE_NAME_OR_PATH \
--model_type rag_sequence \
--fp16 \
--gpus 8 \
--do_train \
--do_predict \
--n_val -1 \
--val_check_interval 0.25 \
--train_batch_size 8 \
--eval_batch_size 1 \
--max_source_length 128 \
--max_target_length 25 \
--val_max_target_length 25 \
--test_max_target_length 25 \
--label_smoothing 0.1 \
--dropout 0.1 \
--attention_dropout 0.1 \
--weight_decay 0.001 \
--adam_epsilon 1e-08 \
--max_grad_norm 0.1 \
--lr_scheduler polynomial \
--learning_rate 3e-05 \
--num_train_epochs 100 \
--warmup_steps 500 \
--gradient_accumulation_steps 1

Просмотреть файл

@ -0,0 +1,47 @@
"""
This script reads DPR retriever training data and parses each datapoint. We save a line per datapoint.
Each line consists of the query followed by a tab-separated list of Wikipedia page titles constituting
positive contexts for a given query.
"""
import argparse
import json
from tqdm import tqdm
def main():
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--src_path",
type=str,
default="biencoder-nq-dev.json",
help="Path to raw DPR training data",
)
parser.add_argument(
"--evaluation_set",
type=str,
help="where to store parsed evaluation_set file",
)
parser.add_argument(
"--gold_data_path",
type=str,
help="where to store parsed gold_data_path file",
)
args = parser.parse_args()
with open(args.src_path, "r") as src_file, open(args.evaluation_set, "w") as eval_file, open(
args.gold_data_path, "w"
) as gold_file:
dpr_records = json.load(src_file)
for dpr_record in tqdm(dpr_records):
question = dpr_record["question"]
contexts = [context["title"] for context in dpr_record["positive_ctxs"]]
eval_file.write(question + "\n")
gold_file.write("\t".join(contexts) + "\n")
if __name__ == "__main__":
main()

Просмотреть файл

@ -0,0 +1,4 @@
faiss-cpu >= 1.6.3
datasets >= 1.0.1
psutil >= 5.7.0
torch >= 1.4.0

Просмотреть файл

@ -0,0 +1,156 @@
import json
import os
import shutil
import sys
import tempfile
import unittest
from unittest import TestCase
from unittest.mock import patch
import numpy as np
from datasets import Dataset
import faiss
from transformers.configuration_bart import BartConfig
from transformers.configuration_dpr import DPRConfig
from transformers.configuration_rag import RagConfig
from transformers.file_utils import is_datasets_available, is_faiss_available, is_psutil_available, is_torch_available
from transformers.tokenization_bart import BartTokenizer
from transformers.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
from transformers.tokenization_dpr import DPRQuestionEncoderTokenizer
from transformers.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # noqa: E402 # isort:skip
from examples.rag.distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
def require_distributed_retrieval(test_case):
"""
Decorator marking a test that requires a set of dependencies necessary for pefrorm retrieval with
:class:`~transformers.RagRetriever`.
These tests are skipped when respective libraries are not installed.
"""
if not (is_torch_available() and is_datasets_available() and is_faiss_available() and is_psutil_available()):
test_case = unittest.skip("test requires PyTorch, Datasets, Faiss, psutil")(test_case)
return test_case
@require_distributed_retrieval
class RagRetrieverTest(TestCase):
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
self.retrieval_vector_size = 8
# DPR tok
vocab_tokens = [
"[UNK]",
"[CLS]",
"[SEP]",
"[PAD]",
"[MASK]",
"want",
"##want",
"##ed",
"wa",
"un",
"runn",
"##ing",
",",
"low",
"lowest",
]
dpr_tokenizer_path = os.path.join(self.tmpdirname, "dpr_tokenizer")
os.makedirs(dpr_tokenizer_path, exist_ok=True)
self.vocab_file = os.path.join(dpr_tokenizer_path, DPR_VOCAB_FILES_NAMES["vocab_file"])
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
# BART tok
vocab = [
"l",
"o",
"w",
"e",
"r",
"s",
"t",
"i",
"d",
"n",
"\u0120",
"\u0120l",
"\u0120n",
"\u0120lo",
"\u0120low",
"er",
"\u0120lowest",
"\u0120newer",
"\u0120wider",
"<unk>",
]
vocab_tokens = dict(zip(vocab, range(len(vocab))))
merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
self.special_tokens_map = {"unk_token": "<unk>"}
bart_tokenizer_path = os.path.join(self.tmpdirname, "bart_tokenizer")
os.makedirs(bart_tokenizer_path, exist_ok=True)
self.vocab_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["vocab_file"])
self.merges_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["merges_file"])
with open(self.vocab_file, "w", encoding="utf-8") as fp:
fp.write(json.dumps(vocab_tokens) + "\n")
with open(self.merges_file, "w", encoding="utf-8") as fp:
fp.write("\n".join(merges))
def get_dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer:
return DPRQuestionEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer"))
def get_bart_tokenizer(self) -> BartTokenizer:
return BartTokenizer.from_pretrained(os.path.join(self.tmpdirname, "bart_tokenizer"))
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def get_dummy_pytorch_distributed_retriever(self, init_retrieval, port=12345) -> RagPyTorchDistributedRetriever:
dataset = Dataset.from_dict(
{
"id": ["0", "1"],
"text": ["foo", "bar"],
"title": ["Foo", "Bar"],
"embeddings": [np.ones(self.retrieval_vector_size), 2 * np.ones(self.retrieval_vector_size)],
}
)
dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)
config = RagConfig(
retrieval_vector_size=self.retrieval_vector_size,
question_encoder=DPRConfig().to_dict(),
generator=BartConfig().to_dict(),
)
with patch("transformers.retrieval_rag.load_dataset") as mock_load_dataset:
mock_load_dataset.return_value = dataset
retriever = RagPyTorchDistributedRetriever(
config,
question_encoder_tokenizer=self.get_dpr_tokenizer(),
generator_tokenizer=self.get_bart_tokenizer(),
)
if init_retrieval:
retriever.init_retrieval(port)
return retriever
def test_pytorch_distributed_retriever_retrieve(self):
n_docs = 1
retriever = self.get_dummy_pytorch_distributed_retriever(init_retrieval=True)
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
self.assertEqual(len(doc_dicts), 2)
self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
self.assertListEqual(list(doc_ids), [1, 0])

187
examples/rag/utils.py Normal file
Просмотреть файл

@ -0,0 +1,187 @@
import linecache
import re
import string
from collections import Counter
from logging import getLogger
from pathlib import Path
from typing import Dict, List
import torch
from torch.utils.data import Dataset
from examples.seq2seq.utils import SortishSampler, trim_batch
from transformers import BartTokenizer, RagTokenizer, T5Tokenizer
def encode_line(tokenizer, line, max_length, padding_side, pad_to_max_length=True, return_tensors="pt"):
extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) and not line.startswith(" ") else {}
tokenizer.padding_side = padding_side
return tokenizer(
[line],
max_length=max_length,
padding="max_length" if pad_to_max_length else None,
truncation=True,
return_tensors=return_tensors,
add_special_tokens=True,
**extra_kw,
)
class Seq2SeqDataset(Dataset):
def __init__(
self,
tokenizer,
data_dir,
max_source_length,
max_target_length,
type_path="train",
n_obs=None,
src_lang=None,
tgt_lang=None,
prefix="",
):
super().__init__()
self.src_file = Path(data_dir).joinpath(type_path + ".source")
self.tgt_file = Path(data_dir).joinpath(type_path + ".target")
self.src_lens = self.get_char_lens(self.src_file)
self.max_source_length = max_source_length
self.max_target_length = max_target_length
assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
self.tokenizer = tokenizer
self.prefix = prefix
if n_obs is not None:
self.src_lens = self.src_lens[:n_obs]
self.src_lang = src_lang
self.tgt_lang = tgt_lang
def __len__(self):
return len(self.src_lens)
def __getitem__(self, index) -> Dict[str, torch.Tensor]:
index = index + 1 # linecache starts at 1
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
assert source_line, f"empty source line for index {index}"
assert tgt_line, f"empty tgt line for index {index}"
# Need to add eos token manually for T5
if isinstance(self.tokenizer, T5Tokenizer):
source_line += self.tokenizer.eos_token
tgt_line += self.tokenizer.eos_token
# Pad source and target to the right
source_tokenizer = (
self.tokenizer.question_encoder if isinstance(self.tokenizer, RagTokenizer) else self.tokenizer
)
target_tokenizer = self.tokenizer.generator if isinstance(self.tokenizer, RagTokenizer) else self.tokenizer
source_inputs = encode_line(source_tokenizer, source_line, self.max_source_length, "right")
target_inputs = encode_line(target_tokenizer, tgt_line, self.max_target_length, "right")
source_ids = source_inputs["input_ids"].squeeze()
target_ids = target_inputs["input_ids"].squeeze()
src_mask = source_inputs["attention_mask"].squeeze()
return {
"input_ids": source_ids,
"attention_mask": src_mask,
"decoder_input_ids": target_ids,
}
@staticmethod
def get_char_lens(data_file):
return [len(x) for x in Path(data_file).open().readlines()]
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
input_ids = torch.stack([x["input_ids"] for x in batch])
masks = torch.stack([x["attention_mask"] for x in batch])
target_ids = torch.stack([x["decoder_input_ids"] for x in batch])
tgt_pad_token_id = (
self.tokenizer.generator.pad_token_id
if isinstance(self.tokenizer, RagTokenizer)
else self.tokenizer.pad_token_id
)
src_pad_token_id = (
self.tokenizer.question_encoder.pad_token_id
if isinstance(self.tokenizer, RagTokenizer)
else self.tokenizer.pad_token_id
)
y = trim_batch(target_ids, tgt_pad_token_id)
source_ids, source_mask = trim_batch(input_ids, src_pad_token_id, attention_mask=masks)
batch = {
"input_ids": source_ids,
"attention_mask": source_mask,
"decoder_input_ids": y,
}
return batch
def make_sortish_sampler(self, batch_size):
return SortishSampler(self.src_lens, batch_size)
logger = getLogger(__name__)
def normalize_answer(s):
"""Lower text and remove punctuation, articles and extra whitespace."""
def remove_articles(text):
return re.sub(r"\b(a|an|the)\b", " ", text)
def white_space_fix(text):
return " ".join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return "".join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
def f1_score(prediction, ground_truth):
prediction_tokens = normalize_answer(prediction).split()
ground_truth_tokens = normalize_answer(ground_truth).split()
common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
num_same = sum(common.values())
if num_same == 0:
return 0
precision = 1.0 * num_same / len(prediction_tokens)
recall = 1.0 * num_same / len(ground_truth_tokens)
f1 = (2 * precision * recall) / (precision + recall)
return f1
def exact_match_score(prediction, ground_truth):
return normalize_answer(prediction) == normalize_answer(ground_truth)
def calculate_exact_match(output_lns: List[str], reference_lns: List[str]) -> Dict:
assert len(output_lns) == len(reference_lns)
em = 0
for hypo, pred in zip(output_lns, reference_lns):
em += exact_match_score(hypo, pred)
if len(output_lns) > 0:
em /= len(output_lns)
return {"em": em}
def is_rag_model(model_prefix):
return model_prefix.startswith("rag")
def set_extra_model_params(extra_params, hparams, config):
equivalent_param = {p: p for p in extra_params}
# T5 models don't have `dropout` param, they have `dropout_rate` instead
equivalent_param["dropout"] = "dropout_rate"
for p in extra_params:
if getattr(hparams, p, None):
if not hasattr(config, p) and not hasattr(config, equivalent_param[p]):
logger.info("config doesn't have a `{}` attribute".format(p))
delattr(hparams, p)
continue
set_p = p if hasattr(config, p) else equivalent_param[p]
setattr(config, set_p, getattr(hparams, p))
delattr(hparams, p)
return hparams, config

Просмотреть файл

@ -8,7 +8,7 @@ tensorflow_datasets
pytorch-lightning==0.8.5
matplotlib
git-python==1.0.3
faiss
faiss-cpu
streamlit
elasticsearch
pandas

Просмотреть файл

@ -10,7 +10,7 @@ known_third_party =
datasets
elasticsearch
fairseq
faiss
faiss-cpu
fastprogress
fire
fugashi

Просмотреть файл

@ -89,7 +89,8 @@ extras["onnxruntime"] = ["onnxruntime>=1.4.0", "onnxruntime-tools>=1.4.2"]
extras["serving"] = ["pydantic", "uvicorn", "fastapi", "starlette"]
extras["all"] = extras["serving"] + ["tensorflow", "torch"]
extras["testing"] = ["pytest", "pytest-xdist", "timeout-decorator", "psutil", "parameterized"]
extras["retrieval"] = ["faiss-cpu", "datasets"]
extras["testing"] = ["pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil"] + extras["retrieval"]
# sphinx-rtd-theme==0.5.0 introduced big changes in the style.
extras["docs"] = ["recommonmark", "sphinx", "sphinx-markdown-tables", "sphinx-rtd-theme==0.4.3", "sphinx-copybutton"]
extras["quality"] = ["black >= 20.8b1", "isort >= 5", "flake8 >= 3.8.3"]

Просмотреть файл

@ -42,6 +42,7 @@ from .configuration_mmbt import MMBTConfig
from .configuration_mobilebert import MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileBertConfig
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
from .configuration_pegasus import PegasusConfig
from .configuration_rag import RagConfig
from .configuration_reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig
from .configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
@ -86,6 +87,7 @@ from .file_utils import (
cached_path,
is_apex_available,
is_datasets_available,
is_faiss_available,
is_psutil_available,
is_py3nvml_available,
is_tf_available,
@ -140,6 +142,9 @@ from .pipelines import (
pipeline,
)
# Retriever
from .retrieval_rag import RagRetriever
# Tokenizers
from .tokenization_albert import AlbertTokenizer
from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
@ -172,6 +177,7 @@ from .tokenization_mobilebert import MobileBertTokenizer, MobileBertTokenizerFas
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
from .tokenization_pegasus import PegasusTokenizer
from .tokenization_phobert import PhobertTokenizer
from .tokenization_rag import RagTokenizer
from .tokenization_reformer import ReformerTokenizer
from .tokenization_retribert import RetriBertTokenizer, RetriBertTokenizerFast
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
@ -416,6 +422,7 @@ if is_torch_available():
load_tf_weights_in_openai_gpt,
)
from .modeling_pegasus import PegasusForConditionalGeneration
from .modeling_rag import RagModel, RagSequenceForGeneration, RagTokenForGeneration
from .modeling_reformer import (
REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
ReformerAttention,

Просмотреть файл

@ -24,6 +24,7 @@ from .configuration_bert_generation import BertGenerationConfig
from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
from .configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig
from .configuration_dpr import DPR_PRETRAINED_CONFIG_ARCHIVE_MAP, DPRConfig
from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig
from .configuration_encoder_decoder import EncoderDecoderConfig
from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig
@ -38,6 +39,7 @@ from .configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfi
from .configuration_mobilebert import MobileBertConfig
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
from .configuration_pegasus import PegasusConfig
from .configuration_rag import RagConfig
from .configuration_reformer import ReformerConfig
from .configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
@ -75,6 +77,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP,
LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
DPR_PRETRAINED_CONFIG_ARCHIVE_MAP,
]
for key, value, in pretrained_map.items()
)
@ -110,7 +113,9 @@ CONFIG_MAPPING = OrderedDict(
("encoder-decoder", EncoderDecoderConfig),
("funnel", FunnelConfig),
("lxmert", LxmertConfig),
("dpr", DPRConfig),
("layoutlm", LayoutLMConfig),
("rag", RagConfig),
]
)
@ -145,6 +150,8 @@ MODEL_NAMES_MAPPING = OrderedDict(
("funnel", "Funnel Transformer"),
("lxmert", "LXMERT"),
("layoutlm", "LayoutLM"),
("dpr", "DPR"),
("rag", "RAG"),
]
)

Просмотреть файл

@ -14,7 +14,7 @@
# limitations under the License.
""" DPR model configuration """
from .configuration_bert import BertConfig
from .configuration_utils import PretrainedConfig
from .utils import logging
@ -27,7 +27,7 @@ DPR_PRETRAINED_CONFIG_ARCHIVE_MAP = {
}
class DPRConfig(BertConfig):
class DPRConfig(PretrainedConfig):
r"""
:class:`~transformers.DPRConfig` is the configuration class to store the configuration of a
`DPRModel`.
@ -36,12 +36,73 @@ class DPRConfig(BertConfig):
It is used to instantiate the components of the DPR model.
Args:
projection_dim (:obj:`int`, optional, defaults to 0):
vocab_size (:obj:`int`, `optional`, defaults to 30522):
Vocabulary size of the DPR model. Defines the different tokens that
can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.BertModel`.
hidden_size (:obj:`int`, `optional`, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (:obj:`int`, `optional`, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (:obj:`int`, `optional`, defaults to 3072):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler.
If string, :obj:`"gelu"`, :obj:`"relu"`, :obj:`"swish"` and :obj:`"gelu_new"` are supported.
hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
The dropout ratio for the attention probabilities.
max_position_embeddings (:obj:`int`, `optional`, defaults to 512):
The maximum sequence length that this model might ever be used with.
Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
type_vocab_size (:obj:`int`, `optional`, defaults to 2):
The vocabulary size of the `token_type_ids` passed into :class:`~transformers.BertModel`.
initializer_range (:obj:`float`, `optional`, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
The epsilon used by the layer normalization layers.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
projection_dim (:obj:`int`, `optional`, defaults to 0):
Dimension of the projection for the context and question encoders.
If it is set to zero (default), then no projection is done.
"""
model_type = "dpr"
def __init__(self, projection_dim: int = 0, **kwargs): # projection of the encoders, 0 for no projection
super().__init__(**kwargs)
def __init__(
self,
vocab_size=30522,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
layer_norm_eps=1e-12,
pad_token_id=0,
gradient_checkpointing=False,
projection_dim: int = 0,
**kwargs
):
super().__init__(pad_token_id=pad_token_id, **kwargs)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.gradient_checkpointing = gradient_checkpointing
self.projection_dim = projection_dim

Просмотреть файл

@ -0,0 +1,176 @@
# coding=utf-8
# Copyright 2020, The RAG Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" RAG model configuration """
import copy
from .configuration_utils import PretrainedConfig
from .file_utils import add_start_docstrings
RAG_CONFIG_DOC = r"""
:class:`~transformers.RagConfig` stores the configuration of a `RagModel`.
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used
to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig`
for more information.
Args:
title_sep (:obj:`str`, `optional`, defaults to ``" / "``):
Separator inserted between the title and the text of the retrieved document when calling :class:`~transformers.RagRetriever`.
doc_sep (:obj:`str`, `optional`, defaults to ``" // "``):
Separator inserted between the the text of the retrieved document and the original input when calliang :class:`~transformers.RagRetriever`.
n_docs (:obj:`int`, `optional`, defaults to 5):
Number of documents to retrieve.
max_combined_length (:obj:`int`, `optional`, defaults to 300):
Max length of contextualized input returned by :meth:`~transformers.RagRetriever.__call__`.
retrieval_vector_size (:obj:`int`, `optional`, defaults to 768):
Dimensionality of the document embeddings indexed by :class:`~transformers.RagRetriever`.
retrieval_batch_size (:obj:`int`, `optional`, defaults to 8):
Retrieval batch size, defined as the number of queries issues concurrently to the faiss index excapsulated :class:`~transformers.RagRetriever`.
dataset (:obj:`str`, `optional`, defaults to :obj:`"wiki_dpr"`):
A datatset identifier of the indexed dataset on HuggingFace AWS bucket (list all available datasets and ids using :obj:`datasets.list_datasets()`).
dataset_split (:obj:`str`, `optional`, defaults to :obj:`train`)
Which split of the ``dataset`` to load.
index_name (:obj:`str`, `optional`, defaults to :obj:`compressed`)
The index_name of the index associated with the :obj:`dataset`. One can choose between :obj:`legacy`, :obj:`exact` and :obj:`compressed`.
index_path (:obj:`str`, `optional`)
The path to the serialized faiss index on disk.
passages_path: (:obj:`str`, `optional`):
A path to text passages compatible with the faiss index. Required if using :class:`~transformers.retrieval_rag.LegacyIndex`
use_dummy_dataset (:obj:`bool`, `optional`, defaults to ``False``)
Whether to load a "dummy" variant of the dataset specified by :obj:`dataset`.
label_smoothing (:obj:`float`, `optional`, defaults to 0.0):
Only relevant if ``return_loss`` is set to :obj:`True`. Controls the ``epsilon`` parameter value for label smoothing in the loss calculation.
If set to ``0.0``, no label smoothing is performed.
do_marginalize (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`True`, the logits are marginalized over all documents
by making use of ``torch.nn.functional.log_softmax``.
reduce_loss (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`True`, the NLL loss is reduced using the ``torch.Tensor.sum`` operation.
do_deduplication (:obj:`bool`, `optional`, defaults to :obj:`True`):
Controls whether we want to deduplicate the generations from different context documents for a given input.
Has to be set to :obj:`False` if used while training with distributed backend.
exclude_bos_score (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`True`, the score of the BOS token is disregarded when computing
the loss.
output_retrieved(:obj:`bool`, `optional`, defaults to :obj:`False`):
If set to ``True``, :obj:`retrieved_doc_embeds`, :obj:`retrieved_doc_ids`, :obj:`context_input_ids` and :obj:`context_attention_mask` are returned. See returned tensors for more detail.
"""
@add_start_docstrings(RAG_CONFIG_DOC)
class RagConfig(PretrainedConfig):
model_type = "rag"
def __init__(
self,
vocab_size=None,
is_encoder_decoder=True,
prefix=None,
bos_token_id=None,
pad_token_id=None,
eos_token_id=None,
decoder_start_token_id=None,
title_sep=" / ",
doc_sep=" // ",
n_docs=5,
max_combined_length=300,
retrieval_vector_size=768,
retrieval_batch_size=8,
dataset="wiki_dpr",
dataset_split="train",
index_name="compressed",
index_path=None,
passages_path=None,
use_dummy_dataset=False,
reduce_loss=False,
label_smoothing=0.0,
do_deduplication=True,
exclude_bos_score=False,
do_marginalize=False,
output_retrieved=False,
**kwargs
):
super().__init__(
bos_token_id=bos_token_id,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
decoder_start_token_id=decoder_start_token_id,
is_encoder_decoder=is_encoder_decoder,
prefix=prefix,
vocab_size=vocab_size,
**kwargs,
)
assert (
"question_encoder" in kwargs and "generator" in kwargs
), "Config has to be initialized with question_encoder and generator config"
question_encoder_config = kwargs.pop("question_encoder")
question_encoder_model_type = question_encoder_config.pop("model_type")
decoder_config = kwargs.pop("generator")
decoder_model_type = decoder_config.pop("model_type")
from .configuration_auto import AutoConfig
self.question_encoder = AutoConfig.for_model(question_encoder_model_type, **question_encoder_config)
self.generator = AutoConfig.for_model(decoder_model_type, **decoder_config)
self.reduce_loss = reduce_loss
self.label_smoothing = label_smoothing
self.exclude_bos_score = exclude_bos_score
self.do_marginalize = do_marginalize
self.title_sep = title_sep
self.doc_sep = doc_sep
self.n_docs = n_docs
self.max_combined_length = max_combined_length
self.dataset = dataset
self.dataset_split = dataset_split
self.index_name = index_name
self.retrieval_vector_size = retrieval_vector_size
self.retrieval_batch_size = retrieval_batch_size
self.passages_path = passages_path
self.index_path = index_path
self.use_dummy_dataset = use_dummy_dataset
self.output_retrieved = output_retrieved
self.do_deduplication = do_deduplication
@classmethod
def from_question_encoder_generator_configs(
cls, question_encoder_config: PretrainedConfig, generator_config: PretrainedConfig, **kwargs
) -> PretrainedConfig:
r"""
Instantiate a :class:`~transformers.EncoderDecoderConfig` (or a derived class) from a pre-trained encoder model configuration and decoder model configuration.
Returns:
:class:`EncoderDecoderConfig`: An instance of a configuration object
"""
return cls(question_encoder=question_encoder_config.to_dict(), generator=generator_config.to_dict(), **kwargs)
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default :meth:`~transformers.PretrainedConfig.to_dict`.
Returns:
:obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["question_encoder"] = self.question_encoder.to_dict()
output["generator"] = self.generator.to_dict()
output["model_type"] = self.__class__.model_type
return output

Просмотреть файл

@ -69,6 +69,7 @@ try:
import datasets # noqa: F401
_datasets_available = True
logger.debug(f"Succesfully imported datasets version {datasets.__version__}")
except ImportError:
_datasets_available = False
@ -119,6 +120,16 @@ try:
except ImportError:
_has_apex = False
try:
import faiss # noqa: F401
_faiss_available = True
logger.debug(f"Succesfully imported faiss version {faiss.__version__}")
except ImportError:
_faiss_available = False
default_cache_path = os.path.join(torch_cache_home, "transformers")
@ -175,6 +186,10 @@ def is_apex_available():
return _has_apex
def is_faiss_available():
return _faiss_available
def add_start_docstrings(*docstr):
def docstring_decorator(fn):
fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")

Просмотреть файл

@ -27,6 +27,7 @@ from .configuration_auto import (
CamembertConfig,
CTRLConfig,
DistilBertConfig,
DPRConfig,
ElectraConfig,
EncoderDecoderConfig,
FlaubertConfig,
@ -97,6 +98,7 @@ from .modeling_distilbert import (
DistilBertForTokenClassification,
DistilBertModel,
)
from .modeling_dpr import DPRQuestionEncoder
from .modeling_electra import (
ElectraForMaskedLM,
ElectraForMultipleChoice,
@ -148,6 +150,11 @@ from .modeling_mobilebert import (
)
from .modeling_openai import OpenAIGPTLMHeadModel, OpenAIGPTModel
from .modeling_pegasus import PegasusForConditionalGeneration
from .modeling_rag import ( # noqa: F401 - need to import all RagModels to be in globals() function
RagModel,
RagSequenceForGeneration,
RagTokenForGeneration,
)
from .modeling_reformer import (
ReformerForMaskedLM,
ReformerForQuestionAnswering,
@ -224,6 +231,7 @@ MODEL_MAPPING = OrderedDict(
(FunnelConfig, FunnelModel),
(LxmertConfig, LxmertModel),
(BertGenerationConfig, BertGenerationEncoder),
(DPRConfig, DPRQuestionEncoder),
]
)
@ -412,7 +420,6 @@ MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
]
)
AUTO_MODEL_PRETRAINED_DOCSTRING = r"""
The model class to instantiate is selected based on the :obj:`model_type` property of the config object

Просмотреть файл

@ -272,6 +272,7 @@ class DPRPretrainedContextEncoder(PreTrainedModel):
config_class = DPRConfig
load_tf_weights = None
base_model_prefix = "ctx_encoder"
authorized_missing_keys = [r"position_ids"]
def init_weights(self):
self.ctx_encoder.init_weights()
@ -285,6 +286,7 @@ class DPRPretrainedQuestionEncoder(PreTrainedModel):
config_class = DPRConfig
load_tf_weights = None
base_model_prefix = "question_encoder"
authorized_missing_keys = [r"position_ids"]
def init_weights(self):
self.question_encoder.init_weights()
@ -298,6 +300,7 @@ class DPRPretrainedReader(PreTrainedModel):
config_class = DPRConfig
load_tf_weights = None
base_model_prefix = "span_predictor"
authorized_missing_keys = [r"position_ids"]
def init_weights(self):
self.span_predictor.encoder.init_weights()

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -0,0 +1,470 @@
# coding=utf-8
# Copyright 2020, The RAG Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RAG Retriever model implementation."""
import os
import pickle
import time
from typing import Iterable, List, Optional, Tuple
import numpy as np
from .configuration_rag import RagConfig
from .file_utils import cached_path, is_datasets_available, is_faiss_available, is_remote_url
from .tokenization_rag import RagTokenizer
from .tokenization_utils_base import BatchEncoding
from .utils import logging
if is_datasets_available() and is_faiss_available():
from datasets import load_dataset
import faiss
logger = logging.get_logger(__name__)
LEGACY_INDEX_PATH = "https://storage.googleapis.com/huggingface-nlp/datasets/wiki_dpr/"
class Index:
"""
A base class for the Indices encapsulated by the :class:`~transformers.RagRetriever`.
"""
def get_doc_dicts(self, doc_ids: np.ndarray) -> List[dict]:
"""
Returns a list of dictionaries, containing titles and text of the retrieved documents.
Args:
doc_ids (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs)`):
A tensor of document indices.
"""
raise NotImplementedError
def get_top_docs(self, question_hidden_states: np.ndarray, n_docs=5) -> Tuple[np.ndarray, np.ndarray]:
"""
For each query in the batch, retrieves ``n_docs`` documents.
Args:
question_hidden_states (:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size):
An array of query vectors.
n_docs (:obj:`int`):
The number of docs retrieved per query.
Returns:
:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs)`: A tensor of indices of retrieved documents.
:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size)`: A tensor of vector representations of retrieved documents.
"""
raise NotImplementedError
def is_initialized(self):
"""
Returns :obj:`True` if index is already initialized.
"""
raise NotImplementedError
def init_index(self):
"""
A function responsible for loading the index into memory. Should be called only once per training run of a RAG model.
E.g. if the model is trained on multiple GPUs in a distributed setup, only one of the workers will load the index.
"""
raise NotImplementedError
class LegacyIndex:
"""
An index which can be deserialized from the files built using https://github.com/facebookresearch/DPR.
We use default faiss index parameters as specified in that repository.
Args:
vector_size (:obj:`int`):
The dimension of indexed vectors.
index_path (:obj:`str`):
A path to a `directory` containing index files compatible with
:class:`~transformers.retrieval_rag.LegacyIndex`
"""
INDEX_FILENAME = "hf_bert_base.hnswSQ8_correct_phi_128.c_index"
PASSAGE_FILENAME = "psgs_w100.tsv.pkl"
def __init__(self, vector_size, index_path):
self.index_id_to_db_id = []
self.index_path = index_path
self.passages = self._load_passages()
self.vector_size = vector_size
self.index = None
self._index_initialize = False
def _resolve_path(self, index_path, filename):
assert os.path.isdir(index_path) or is_remote_url(index_path), "Please specify a valid ``index_path``."
archive_file = os.path.join(index_path, filename)
try:
# Load from URL or cache if already cached
resolved_archive_file = cached_path(archive_file)
if resolved_archive_file is None:
raise EnvironmentError
except EnvironmentError:
msg = (
f"Can't load '{archive_file}'. Make sure that:\n\n"
f"- '{index_path}' is a correct remote path to a directory containing a file named {filename}"
f"- or '{index_path}' is the correct path to a directory containing a file named {filename}.\n\n"
)
raise EnvironmentError(msg)
if resolved_archive_file == archive_file:
logger.info("loading file {}".format(archive_file))
else:
logger.info("loading file {} from cache at {}".format(archive_file, resolved_archive_file))
return resolved_archive_file
def _load_passages(self):
logger.info("Loading passages from {}".format(self.index_path))
passages_path = self._resolve_path(self.index_path, self.PASSAGE_FILENAME)
with open(passages_path, "rb") as passages_file:
passages = pickle.load(passages_file)
return passages
def _deserialize_index(self):
logger.info("Loading index from {}".format(self.index_path))
resolved_index_path = self._resolve_path(self.index_path, self.INDEX_FILENAME + ".index.dpr")
self.index = faiss.read_index(resolved_index_path)
resolved_meta_path = self._resolve_path(self.index_path, self.INDEX_FILENAME + ".index_meta.dpr")
with open(resolved_meta_path, "rb") as metadata_file:
self.index_id_to_db_id = pickle.load(metadata_file)
assert (
len(self.index_id_to_db_id) == self.index.ntotal
), "Deserialized index_id_to_db_id should match faiss index size"
def is_initialized(self):
return self._index_initialize
def init_index(self):
index = faiss.IndexHNSWFlat(self.vector_size + 1, 512)
index.hnsw.efSearch = 128
index.hnsw.efConstruction = 200
self.index = index
self._deserialize_index()
self._index_initialize = True
def get_doc_dicts(self, doc_ids: np.array):
doc_list = []
for doc_ids_i in doc_ids:
ids = [str(int(doc_id)) for doc_id in doc_ids_i]
docs = [self.passages[doc_id] for doc_id in ids]
doc_list.append(docs)
doc_dicts = []
for docs in doc_list:
doc_dict = {}
doc_dict["title"] = [doc[1] for doc in docs]
doc_dict["text"] = [doc[0] for doc in docs]
doc_dicts.append(doc_dict)
return doc_dicts
def get_top_docs(self, question_hidden_states: np.ndarray, n_docs=5) -> Tuple[np.ndarray, np.ndarray]:
aux_dim = np.zeros(len(question_hidden_states), dtype="float32").reshape(-1, 1)
query_nhsw_vectors = np.hstack((question_hidden_states, aux_dim))
_, docs_ids = self.index.search(query_nhsw_vectors, n_docs)
vectors = [[self.index.reconstruct(int(doc_id))[:-1] for doc_id in doc_ids] for doc_ids in docs_ids]
ids = [[int(self.index_id_to_db_id[doc_id]) for doc_id in doc_ids] for doc_ids in docs_ids]
return np.array(ids), np.array(vectors)
class HFIndex:
"""
A wrapper around an instance of :class:`~datasets.Datasets`. If ``index_path`` is set to ``None``,
we load the pre-computed index available with the :class:`~datasets.arrow_dataset.Dataset`, otherwise, we load the index from the indicated path on disk.
Args:
dataset (:obj:`str`, optional, defaults to ``wiki_dpr``):
A datatset identifier of the indexed dataset on HuggingFace AWS bucket (list all available datasets and ids with ``datasets.list_datasets()``).
dataset_split (:obj:`str`, optional, defaults to ``train``)
Which split of the ``dataset`` to load.
index_name (:obj:`str`, optional, defaults to ``train``)
The index_name of the index associated with the ``dataset``. The index loaded from ``index_path`` will be saved under this name.
index_path (:obj:`str`, optional, defaults to ``None``)
The path to the serialized faiss index on disk.
"""
def __init__(
self,
dataset_name: str,
dataset_split: str,
index_name: str,
index_path: Optional[str] = None,
use_dummy_dataset=False,
):
super().__init__()
self.dataset_name = dataset_name
self.dataset_split = dataset_split
self.index_name = index_name
self.index_path = index_path
self.use_dummy_dataset = use_dummy_dataset
self._index_initialize = False
logger.info("Loading passages from {}".format(self.dataset_name))
self.dataset = load_dataset(
self.dataset_name, with_index=False, split=self.dataset_split, dummy=self.use_dummy_dataset
)
def is_initialized(self):
return self._index_initialize
def init_index(self):
if self.index_path is not None:
logger.info("Loading index from {}".format(self.index_path))
self.index.load_faiss_index(index_name=self.index_name, file=self.index_path)
else:
logger.info("Loading index from {}".format(self.dataset_name + " with index name " + self.index_name))
self.dataset = load_dataset(
self.dataset_name,
with_embeddings=True,
with_index=True,
split=self.dataset_split,
index_name=self.index_name,
dummy=self.use_dummy_dataset,
)
self._index_initialize = True
def get_doc_dicts(self, doc_ids: np.ndarray) -> List[dict]:
return [self.dataset[doc_ids[i].tolist()] for i in range(doc_ids.shape[0])]
def get_top_docs(self, question_hidden_states: np.ndarray, n_docs=5) -> Tuple[np.ndarray, np.ndarray]:
_, docs = self.dataset.get_nearest_examples_batch("embeddings", question_hidden_states, n_docs)
ids = [[int(i) for i in doc["id"]] for doc in docs]
vectors = [doc["embeddings"] for doc in docs]
return np.array(ids), np.array(vectors) # shapes (batch_size, n_docs) and (batch_size, n_docs, d)
class RagRetriever:
"""
Retriever used to get documents from vector queries.
It retrieves the documents embeddings as well as the documents contents, and it formats them to be used with a RagModel.
Args:
config (:class:`~transformers.RagConfig`):
The configuration of the RAG model this Retriever is used with. Contains parameters indicating which ``Index`` to build.
question_encoder_tokenizer (:class:`~transformers.PreTrainedTokenizer`):
The tokenizer that was used to tokenize the question.
It is used to decode the question and then use the generator_tokenizer.
generator_tokenizer (:class:`~transformers.PreTrainedTokenizer`):
The tokenizer used for the generator part of the RagModel.
"""
_init_retrieval = True
def __init__(self, config, question_encoder_tokenizer, generator_tokenizer):
super().__init__()
self.index = (
LegacyIndex(
config.retrieval_vector_size,
config.index_path or LEGACY_INDEX_PATH,
)
if config.index_name == "legacy"
else HFIndex(
config.dataset, config.dataset_split, config.index_name, config.index_path, config.use_dummy_dataset
)
)
self.generator_tokenizer = generator_tokenizer
self.question_encoder_tokenizer = question_encoder_tokenizer
self.n_docs = config.n_docs
self.batch_size = config.retrieval_batch_size
self.config = config
if self._init_retrieval:
self.init_retrieval()
@classmethod
def from_pretrained(cls, retriever_name_or_path, **kwargs):
config = RagConfig.from_pretrained(retriever_name_or_path, **kwargs)
rag_tokenizer = RagTokenizer.from_pretrained(retriever_name_or_path, config=config)
question_encoder_tokenizer = rag_tokenizer.question_encoder
generator_tokenizer = rag_tokenizer.generator
return cls(
config, question_encoder_tokenizer=question_encoder_tokenizer, generator_tokenizer=generator_tokenizer
)
def save_pretrained(self, save_directory):
self.config.save_pretrained(save_directory)
rag_tokenizer = RagTokenizer(
question_encoder_tokenizer=self.question_encoder_tokenizer,
generator_tokenizer=self.generator_tokenizer,
)
rag_tokenizer.save_pretrained(save_directory)
def init_retrieval(self):
"""
Retriever initalization function. It loads the index into memory.
"""
logger.info("initializing retrieval")
self.index.init_index()
def postprocess_docs(self, docs, input_strings, prefix, n_docs, return_tensors=None):
r"""
Postprocessing retrieved ``docs`` and combining them with ``input_strings``.
Args:
doc_scores (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs)`):
Retrieval scores of respective docs - passed for logging.
docs (:obj:`dict`):
Retrieved documents.
input_strings (:obj:`str`):
Input strings decoded by ``preprocess_query``.
prefix (:obj:`str`):
Prefix added at the beginning of each input, typically used with T5-based models.
Return:
:obj:`tuple(tensors)`:
a tuple consisting of two elements: contextualized ``input_ids`` and a compatible ``attention_mask``.
"""
def cat_input_and_doc(doc_title, doc_text, input_string, prefix):
# TODO(Patrick): if we train more RAG models, I want to put the input first to take advantage of effortless truncation
# TODO(piktus): better handling of truncation
if doc_title.startswith('"'):
doc_title = doc_title[1:]
if doc_title.endswith('"'):
doc_title = doc_title[:-1]
if prefix is None:
prefix = ""
out = (prefix + doc_title + self.config.title_sep + doc_text + self.config.doc_sep + input_string).replace(
" ", " "
)
return out
rag_input_strings = [
cat_input_and_doc(
docs[i]["title"][j],
docs[i]["text"][j],
input_strings[i],
prefix,
)
for i in range(len(docs))
for j in range(n_docs)
]
contextualized_inputs = self.generator_tokenizer.batch_encode_plus(
rag_input_strings,
max_length=self.config.max_combined_length,
return_tensors=return_tensors,
padding="max_length",
truncation=True,
)
return contextualized_inputs["input_ids"], contextualized_inputs["attention_mask"]
def _chunk_tensor(self, t: Iterable, chunk_size: int) -> List[Iterable]:
return [t[i : i + chunk_size] for i in range(0, len(t), chunk_size)]
def _main_retrieve(self, question_hidden_states: np.ndarray, n_docs: int) -> Tuple[np.ndarray, np.ndarray]:
question_hidden_states_batched = self._chunk_tensor(question_hidden_states, self.batch_size)
ids_batched = []
vectors_batched = []
for question_hidden_states in question_hidden_states_batched:
start_time = time.time()
ids, vectors = self.index.get_top_docs(question_hidden_states, n_docs)
logger.debug(
"index search time: {} sec, batch size {}".format(
time.time() - start_time, question_hidden_states.shape
)
)
ids_batched.extend(ids)
vectors_batched.extend(vectors)
return np.array(ids_batched), np.array(
vectors_batched
) # shapes (batch_size, n_docs) and (batch_size, n_docs, d)
def retrieve(self, question_hidden_states: np.ndarray, n_docs: int) -> Tuple[np.ndarray, List[dict]]:
"""
Retrieves documents for specified ``question_hidden_states``.
Args:
question_hidden_states (:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size)`):
A batch of query vectors to retrieve with.
n_docs (:obj:`int`):
The number of docs retrieved per query.
Return:
retrieved_doc_embeds (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs, dim)`
The retrieval embeddings of the retrieved docs per query.
doc_ids (:obj:`np.ndarray` of shape :obj:`batch_size, n_docs`)
The ids of the documents in the index
doc_dicts (:obj:`List[dict]`):
The retrieved_doc_embeds examples per query.
"""
doc_ids, retrieved_doc_embeds = self._main_retrieve(question_hidden_states, n_docs)
return retrieved_doc_embeds, doc_ids, self.index.get_doc_dicts(doc_ids)
def __call__(
self,
question_input_ids: List[List[int]],
question_hidden_states: np.ndarray,
prefix=None,
n_docs=None,
return_tensors=None,
) -> BatchEncoding:
"""
Retrieves documents for specified :obj:`question_hidden_states`.
Args:
question_input_ids: (:obj:`List[List[int]]`) batch of input ids
question_hidden_states (:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size)`:
A batch of query vectors to retrieve with.
prefix: (:obj:`str`, `optional`):
The prefix used by the generator's tokenizer.
n_docs (:obj:`int`, `optional`):
The number of docs retrieved per query.
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to "pt"):
If set, will return tensors instead of list of python integers. Acceptable values are:
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
Output:
:class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:
- **context_input_ids** -- List of token ids to be fed to a model.
`What are input IDs? <../glossary.html#input-ids>`__
- **context_attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
:obj:`return_attention_mask=True` or if `"attention_mask"` is in :obj:`self.model_input_names`).
`What are attention masks? <../glossary.html#attention-mask>`__
- **retrieved_doc_embeds** -- List of embeddings of the retrieved documents
- **doc_ids** -- List of ids of the retrieved documents
"""
n_docs = n_docs if n_docs is not None else self.n_docs
prefix = prefix if prefix is not None else self.config.generator.prefix
retrieved_doc_embeds, doc_ids, docs = self.retrieve(question_hidden_states, n_docs)
input_strings = self.question_encoder_tokenizer.batch_decode(question_input_ids, skip_special_tokens=True)
context_input_ids, context_attention_mask = self.postprocess_docs(
docs, input_strings, prefix, n_docs, return_tensors=return_tensors
)
return BatchEncoding(
{
"context_input_ids": context_input_ids,
"context_attention_mask": context_attention_mask,
"retrieved_doc_embeds": retrieved_doc_embeds,
"doc_ids": doc_ids,
},
tensor_type=return_tensors,
)

Просмотреть файл

@ -10,7 +10,7 @@ from distutils.util import strtobool
from io import StringIO
from pathlib import Path
from .file_utils import _tf_available, _torch_available, _torch_tpu_available
from .file_utils import _datasets_available, _faiss_available, _tf_available, _torch_available, _torch_tpu_available
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
@ -161,6 +161,21 @@ def require_torch_and_cuda(test_case):
return test_case
def require_datasets(test_case):
"""Decorator marking a test that requires datasets."""
if not _datasets_available:
test_case = unittest.skip("test requires Datasets")(test_case)
return test_case
def require_faiss(test_case):
"""Decorator marking a test that requires faiss."""
if not _faiss_available:
test_case = unittest.skip("test requires Faiss")(test_case)
return test_case
def get_tests_dir():
"""
returns the full path to the `tests` dir, so that the tests can be invoked from anywhere

Просмотреть файл

@ -26,6 +26,7 @@ from .configuration_auto import (
CamembertConfig,
CTRLConfig,
DistilBertConfig,
DPRConfig,
ElectraConfig,
EncoderDecoderConfig,
FlaubertConfig,
@ -40,6 +41,7 @@ from .configuration_auto import (
MobileBertConfig,
OpenAIGPTConfig,
PegasusConfig,
RagConfig,
ReformerConfig,
RetriBertConfig,
RobertaConfig,
@ -60,6 +62,7 @@ from .tokenization_bertweet import BertweetTokenizer
from .tokenization_camembert import CamembertTokenizer
from .tokenization_ctrl import CTRLTokenizer
from .tokenization_distilbert import DistilBertTokenizer, DistilBertTokenizerFast
from .tokenization_dpr import DPRQuestionEncoderTokenizer, DPRQuestionEncoderTokenizerFast
from .tokenization_electra import ElectraTokenizer, ElectraTokenizerFast
from .tokenization_flaubert import FlaubertTokenizer
from .tokenization_fsmt import FSMTTokenizer
@ -74,6 +77,7 @@ from .tokenization_mobilebert import MobileBertTokenizer, MobileBertTokenizerFas
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
from .tokenization_pegasus import PegasusTokenizer
from .tokenization_phobert import PhobertTokenizer
from .tokenization_rag import RagTokenizer
from .tokenization_reformer import ReformerTokenizer
from .tokenization_retribert import RetriBertTokenizer, RetriBertTokenizerFast
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
@ -110,6 +114,7 @@ TOKENIZER_MAPPING = OrderedDict(
(FunnelConfig, (FunnelTokenizer, FunnelTokenizerFast)),
(LxmertConfig, (LxmertTokenizer, LxmertTokenizerFast)),
(LayoutLMConfig, (LayoutLMTokenizer, LayoutLMTokenizerFast)),
(DPRConfig, (DPRQuestionEncoderTokenizer, DPRQuestionEncoderTokenizerFast)),
(BertConfig, (BertTokenizer, BertTokenizerFast)),
(OpenAIGPTConfig, (OpenAIGPTTokenizer, OpenAIGPTTokenizerFast)),
(GPT2Config, (GPT2Tokenizer, GPT2TokenizerFast)),
@ -121,6 +126,7 @@ TOKENIZER_MAPPING = OrderedDict(
(FSMTConfig, (FSMTTokenizer, None)),
(BertGenerationConfig, (BertGenerationTokenizer, None)),
(LayoutLMConfig, (LayoutLMTokenizer, None)),
(RagConfig, (RagTokenizer, None)),
]
)

Просмотреть файл

@ -0,0 +1,160 @@
# coding=utf-8
# Copyright 2020, The RAG Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for RAG."""
import os
from typing import List, Optional
from .configuration_rag import RagConfig
from .tokenization_utils_base import BatchEncoding
from .utils import logging
logger = logging.get_logger(__name__)
class RagTokenizer:
def __init__(self, question_encoder, generator):
self.question_encoder = question_encoder
self.generator = generator
def save_pretrained(self, save_directory):
if os.path.isfile(save_directory):
raise ValueError("Provided path ({}) should be a directory, not a file".format(save_directory))
os.makedirs(save_directory, exist_ok=True)
question_encoder_path = os.path.join(save_directory, "question_encoder_tokenizer")
generator_path = os.path.join(save_directory, "generator_tokenizer")
self.question_encoder.save_pretrained(question_encoder_path)
self.generator.save_pretrained(generator_path)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
# dynamically import AutoTokenizer
from .tokenization_auto import AutoTokenizer
config = kwargs.pop("config", None)
if config is None:
config = RagConfig.from_pretrained(pretrained_model_name_or_path)
question_encoder_path = os.path.join(pretrained_model_name_or_path, "question_encoder_tokenizer")
generator_path = os.path.join(pretrained_model_name_or_path, "generator_tokenizer")
question_encoder = AutoTokenizer.from_pretrained(question_encoder_path, config=config.question_encoder)
generator = AutoTokenizer.from_pretrained(generator_path, config=config.generator)
return cls(question_encoder=question_encoder, generator=generator)
def __call__(self, *args, **kwargs):
return self.question_encoder(*args, **kwargs)
def batch_decode(self, *args, **kwargs):
return self.generator.batch_decode(*args, **kwargs)
def prepare_seq2seq_batch(
self,
src_texts: List[str],
tgt_texts: Optional[List[str]] = None,
max_length: Optional[int] = None,
max_target_length: Optional[int] = None,
padding: str = "longest",
return_tensors: str = "np",
truncation=True,
**kwargs,
) -> BatchEncoding:
r"""
Prepare a batch that can be passed directly to an instance of :class:`~transformers.RagModel`.
Args:
src_texts: (:obj:`List[str]`):
List of documents to summarize or source language texts.
tgt_texts: (:obj:`List[str]`, `optional`):
List of summaries or target language texts.
max_length (:obj:`int`, `optional`):
Controls the maximum length for encoder inputs (documents to summarize or source language texts).
If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum
length is required by one of the truncation/padding parameters. If the model has no specific maximum
input length (like XLNet) truncation/padding to a maximum length will be deactivated.
max_target_length (:obj:`int`, `optional`):
Controls the maximum length of decoder inputs (target language texts or summaries).
If left unset or set to :obj:`None`, this will use the max_length value.
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`):
Activates and controls padding. Accepts the following values:
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
single sequence if provided).
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
maximum acceptable input length for the model if that argument is not provided.
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
different lengths).
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to "pt"):
If set, will return tensors instead of list of python integers. Acceptable values are:
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
Activates and controls truncation. Accepts the following values:
* :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument
:obj:`max_length` or to the maximum acceptable input length for the model if that argument is not
provided. This will truncate token by token, removing a token from the longest sequence in the pair
if a pair of sequences (or a batch of pairs) is provided.
* :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to
the maximum acceptable input length for the model if that argument is not provided. This will only
truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
* :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
to the maximum acceptable input length for the model if that argument is not provided. This will only
truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
* :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with
sequence lengths greater than the model maximum admissible input size).
**kwargs:
Additional keyword arguments passed along to :obj:`self.__call__`.
Returns:
:class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:
- **input_ids** -- List of token ids to be fed to the encoder.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
- **labels** -- List of token ids for tgt_texts
The full set of keys ``[input_ids, attention_mask, labels]``,
will only be returned if tgt_texts is passed. Otherwise, input_ids, attention_mask will be the only keys.
"""
if max_length is None:
max_length = self.question_encoder.model_max_length
model_inputs: BatchEncoding = self.question_encoder(
src_texts,
add_special_tokens=True,
return_tensors=return_tensors,
max_length=max_length,
padding=padding,
truncation=truncation,
**kwargs,
)
if tgt_texts is None:
return model_inputs
# Process tgt_texts
if max_target_length is None:
max_target_length = self.generator.model_max_length
labels = self.generator(
tgt_texts,
add_special_tokens=True,
return_tensors=return_tensors,
padding=padding,
max_length=max_target_length,
truncation=truncation,
**kwargs,
)["input_ids"]
model_inputs["labels"] = labels
return model_inputs

885
tests/test_modeling_rag.py Normal file
Просмотреть файл

@ -0,0 +1,885 @@
# coding=utf-8
# Copyright 2020, The RAG Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import shutil
import tempfile
import unittest
from unittest.mock import patch
import numpy as np
from transformers.file_utils import cached_property, is_datasets_available, is_faiss_available, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from transformers.tokenization_bart import BartTokenizer
from transformers.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
from transformers.tokenization_dpr import DPRQuestionEncoderTokenizer
from transformers.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
from transformers.tokenization_t5 import T5Tokenizer
from .test_modeling_bart import ModelTester as BartModelTester
from .test_modeling_dpr import DPRModelTester
from .test_modeling_t5 import T5ModelTester
TOLERANCE = 1e-3
T5_SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
if is_torch_available() and is_datasets_available() and is_faiss_available():
import torch
from datasets import Dataset
import faiss
from transformers import (
AutoConfig,
AutoModel,
AutoModelForSeq2SeqLM,
RagConfig,
RagModel,
RagRetriever,
RagSequenceForGeneration,
RagTokenForGeneration,
)
from transformers.modeling_outputs import BaseModelOutput
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
if a is None and b is None:
return True
try:
if torch.allclose(a, b, atol=atol):
return True
raise
except Exception:
msg = "{} != {}".format(a, b)
if prefix:
msg = prefix + ": " + msg
raise AssertionError(msg)
def require_retrieval(test_case):
"""
Decorator marking a test that requires a set of dependencies necessary for pefrorm retrieval with
:class:`~transformers.RagRetriever`.
These tests are skipped when respective libraries are not installed.
"""
if not (is_torch_available() and is_datasets_available() and is_faiss_available()):
test_case = unittest.skip("test requires PyTorch")(test_case)
return test_case
@require_torch
@require_retrieval
class RagTestMixin:
all_model_classes = (
(RagModel, RagTokenForGeneration, RagSequenceForGeneration)
if is_torch_available() and is_datasets_available() and is_faiss_available()
else ()
)
retrieval_vector_size = 32
n_docs = 2
max_combined_length = 16
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
# DPR tok
vocab_tokens = [
"[UNK]",
"[CLS]",
"[SEP]",
"[PAD]",
"[MASK]",
"want",
"##want",
"##ed",
"wa",
"un",
"runn",
"##ing",
",",
"low",
"lowest",
]
dpr_tokenizer_path = os.path.join(self.tmpdirname, "dpr_tokenizer")
os.makedirs(dpr_tokenizer_path, exist_ok=True)
self.vocab_file = os.path.join(dpr_tokenizer_path, DPR_VOCAB_FILES_NAMES["vocab_file"])
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
# BART tok
vocab = [
"l",
"o",
"w",
"e",
"r",
"s",
"t",
"i",
"d",
"n",
"\u0120",
"\u0120l",
"\u0120n",
"\u0120lo",
"\u0120low",
"er",
"\u0120lowest",
"\u0120newer",
"\u0120wider",
"<unk>",
]
vocab_tokens = dict(zip(vocab, range(len(vocab))))
merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
self.special_tokens_map = {"unk_token": "<unk>"}
bart_tokenizer_path = os.path.join(self.tmpdirname, "bart_tokenizer")
os.makedirs(bart_tokenizer_path, exist_ok=True)
self.vocab_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["vocab_file"])
self.merges_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["merges_file"])
with open(self.vocab_file, "w", encoding="utf-8") as fp:
fp.write(json.dumps(vocab_tokens) + "\n")
with open(self.merges_file, "w", encoding="utf-8") as fp:
fp.write("\n".join(merges))
t5_tokenizer = T5Tokenizer(T5_SAMPLE_VOCAB)
t5_tokenizer_path = os.path.join(self.tmpdirname, "t5_tokenizer")
t5_tokenizer.save_pretrained(t5_tokenizer_path)
@cached_property
def dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer:
return DPRQuestionEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer"))
@cached_property
def bart_tokenizer(self) -> BartTokenizer:
return BartTokenizer.from_pretrained(os.path.join(self.tmpdirname, "bart_tokenizer"))
@cached_property
def t5_tokenizer(self) -> BartTokenizer:
return T5Tokenizer.from_pretrained(os.path.join(self.tmpdirname, "t5_tokenizer"))
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def get_retriever(self, config):
dataset = Dataset.from_dict(
{
"id": ["0", "1"],
"text": ["foo", "bar"],
"title": ["Foo", "Bar"],
"embeddings": [np.ones(self.retrieval_vector_size), 2 * np.ones(self.retrieval_vector_size)],
}
)
dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)
tokenizer = self.bart_tokenizer if config.generator.model_type == "bart" else self.t5_tokenizer
with patch("transformers.retrieval_rag.load_dataset") as mock_load_dataset:
mock_load_dataset.return_value = dataset
retriever = RagRetriever(
config,
question_encoder_tokenizer=self.dpr_tokenizer,
generator_tokenizer=tokenizer,
)
return retriever
def check_model_with_retriever(
self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
):
self.assertIsNotNone(config.question_encoder)
self.assertIsNotNone(config.generator)
for model_class in self.all_model_classes:
model = model_class(config, retriever=self.get_retriever(config)).to(torch_device)
model.eval()
self.assertTrue(model.config.is_encoder_decoder)
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
)
# logits
self.assertEqual(
outputs.logits.shape,
(self.n_docs * decoder_input_ids.shape[0], decoder_input_ids.shape[1], config.generator.vocab_size),
)
# generator encoder last hidden states
self.assertEqual(
outputs.generator_enc_last_hidden_state.shape,
(self.n_docs * decoder_input_ids.shape[0], self.max_combined_length, config.generator.hidden_size),
)
# doc scores
self.assertEqual(outputs.doc_scores.shape, (input_ids.shape[0], self.n_docs))
def check_model_generate(
self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
):
self.assertIsNotNone(config.question_encoder)
self.assertIsNotNone(config.generator)
for model_class in self.all_model_classes[1:]:
model = model_class(config, retriever=self.get_retriever(config)).to(torch_device)
model.eval()
self.assertTrue(model.config.is_encoder_decoder)
outputs = model.generate(
input_ids=input_ids,
num_beams=2,
num_return_sequences=2,
decoder_start_token_id=config.generator.eos_token_id,
)
self.assertIsNotNone(outputs)
def check_model_without_retriever(
self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
):
self.assertIsNotNone(config.question_encoder)
self.assertIsNotNone(config.generator)
retriever = self.get_retriever(config)
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device)
model.eval()
self.assertTrue(model.config.is_encoder_decoder)
question_hidden_states = model.question_encoder(input_ids, attention_mask=attention_mask)[0]
out = retriever(
input_ids,
question_hidden_states.cpu().detach().to(torch.float32).numpy(),
prefix=config.generator.prefix,
return_tensors="pt",
)
context_input_ids, context_attention_mask, retrieved_doc_embeds = (
out["context_input_ids"],
out["context_attention_mask"],
out["retrieved_doc_embeds"],
)
# cast
retrieved_doc_embeds = retrieved_doc_embeds.to(question_hidden_states)
context_input_ids = context_input_ids.to(input_ids)
context_attention_mask = context_attention_mask.to(input_ids)
# compute doc_scores
doc_scores = torch.bmm(question_hidden_states.unsqueeze(1), retrieved_doc_embeds.transpose(1, 2)).squeeze(
1
)
outputs = model(
context_input_ids=context_input_ids,
context_attention_mask=context_attention_mask,
doc_scores=doc_scores,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
)
# logits
self.assertEqual(
outputs.logits.shape,
(self.n_docs * decoder_input_ids.shape[0], decoder_input_ids.shape[1], config.generator.vocab_size),
)
# generator encoder last hidden states
self.assertEqual(
outputs.generator_enc_last_hidden_state.shape,
(self.n_docs * decoder_input_ids.shape[0], self.max_combined_length, config.generator.hidden_size),
)
# doc scores
self.assertEqual(outputs.doc_scores.shape, (input_ids.shape[0], self.n_docs))
def check_model_with_encoder_outputs(
self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
):
self.assertIsNotNone(config.question_encoder)
self.assertIsNotNone(config.generator)
for model_class in self.all_model_classes:
model = model_class(config, retriever=self.get_retriever(config)).to(torch_device)
model.eval()
self.assertTrue(model.config.is_encoder_decoder)
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
)
encoder_outputs = BaseModelOutput(outputs.generator_enc_last_hidden_state)
# run only generator
outputs = model(
encoder_outputs=encoder_outputs,
doc_scores=outputs.doc_scores,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
)
# logits
self.assertEqual(
outputs.logits.shape,
(self.n_docs * decoder_input_ids.shape[0], decoder_input_ids.shape[1], config.generator.vocab_size),
)
# generator encoder last hidden states
self.assertEqual(
outputs.generator_enc_last_hidden_state.shape,
(self.n_docs * decoder_input_ids.shape[0], self.max_combined_length, config.generator.hidden_size),
)
# doc scores
self.assertEqual(outputs.doc_scores.shape, (input_ids.shape[0], self.n_docs))
def test_model_with_retriever(self):
inputs_dict = self.config_and_inputs
self.check_model_with_retriever(**inputs_dict)
def test_model_without_retriever(self):
inputs_dict = self.config_and_inputs
self.check_model_without_retriever(**inputs_dict)
def test_model_with_encoder_outputs(self):
inputs_dict = self.config_and_inputs
self.check_model_with_encoder_outputs(**inputs_dict)
def test_model_generate(self):
inputs_dict = self.config_and_inputs
self.check_model_generate(**inputs_dict)
@require_torch
@require_retrieval
class RagDPRBartTest(RagTestMixin, unittest.TestCase):
@cached_property
def config_and_inputs(self):
question_encoder_tester = DPRModelTester(self)
dpr_config_and_inputs = question_encoder_tester.prepare_config_and_inputs()
generator_tester = BartModelTester(self)
bart_config_and_inputs = generator_tester.prepare_config_and_inputs_for_common()
(question_encoder_config, input_ids, _, input_mask, _, _, _) = dpr_config_and_inputs
(generator_config, bart_inputs_dict) = bart_config_and_inputs
decoder_input_ids, decoder_attention_mask = bart_inputs_dict["input_ids"], bart_inputs_dict["attention_mask"]
config = RagConfig.from_question_encoder_generator_configs(
question_encoder_config,
generator_config,
n_docs=self.n_docs,
retrieval_vector_size=self.retrieval_vector_size,
max_combined_length=self.max_combined_length,
use_cache=False,
)
return {
"config": config,
"input_ids": input_ids,
"attention_mask": input_mask,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
}
@require_torch
@require_retrieval
class RagDPRT5Test(RagTestMixin, unittest.TestCase):
@cached_property
def config_and_inputs(self):
question_encoder_tester = DPRModelTester(self)
dpr_config_and_inputs = question_encoder_tester.prepare_config_and_inputs()
generator_tester = T5ModelTester(self, vocab_size=1100, n_positions=30)
t5_config_and_inputs = generator_tester.prepare_config_and_inputs()
(question_encoder_config, input_ids, _, input_mask, _, _, _) = dpr_config_and_inputs
# import ipdb; ipdb.set_trace()
(generator_config, _, decoder_input_ids, _, decoder_attention_mask, _) = t5_config_and_inputs
config = RagConfig.from_question_encoder_generator_configs(
question_encoder_config,
generator_config,
n_docs=self.n_docs,
retrieval_vector_size=self.retrieval_vector_size,
max_combined_length=self.max_combined_length,
use_cache=False,
)
return {
"config": config,
"input_ids": input_ids,
"attention_mask": input_mask,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
}
@require_torch
@require_retrieval
class RagModelIntegrationTests(unittest.TestCase):
@cached_property
def sequence_model(self):
return (
RagSequenceForGeneration.from_pretrained_question_encoder_generator(
"facebook/dpr-question_encoder-single-nq-base", "facebook/bart-large-cnn"
)
.to(torch_device)
.eval()
)
@cached_property
def token_model(self):
return (
RagTokenForGeneration.from_pretrained_question_encoder_generator(
"facebook/dpr-question_encoder-single-nq-base", "facebook/bart-large-cnn"
)
.to(torch_device)
.eval()
)
def get_rag_config(self):
question_encoder_config = AutoConfig.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
generator_config = AutoConfig.from_pretrained("facebook/bart-large-cnn")
return RagConfig.from_question_encoder_generator_configs(
question_encoder_config,
generator_config,
bos_token_id=0,
decoder_start_token_id=2,
eos_token_id=2,
is_encoder_decoder=True,
pad_token_id=1,
vocab_size=50264,
title_sep=" / ",
doc_sep=" // ",
n_docs=5,
max_combined_length=300,
dataset="wiki_dpr",
dataset_split="train",
index_name="exact",
index_path=None,
use_dummy_dataset=True,
retrieval_vector_size=768,
retrieval_batch_size=8,
)
@slow
def test_rag_sequence_inference(self):
rag_config = self.get_rag_config()
rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
"facebook/dpr-question_encoder-single-nq-base"
)
rag_retriever = RagRetriever(
rag_config,
question_encoder_tokenizer=rag_question_encoder_tokenizer,
generator_tokenizer=rag_decoder_tokenizer,
)
rag_sequence = self.sequence_model
rag_sequence.set_retriever(rag_retriever)
input_ids = rag_question_encoder_tokenizer(
"who sings does he love me with reba", return_tensors="pt"
).input_ids
decoder_input_ids = rag_decoder_tokenizer("Linda Davis", return_tensors="pt").input_ids
input_ids = input_ids.to(torch_device)
decoder_input_ids = decoder_input_ids.to(torch_device)
with torch.no_grad():
output = rag_sequence(
input_ids,
labels=decoder_input_ids,
)
expected_shape = torch.Size([5, 5, 50264])
self.assertEqual(output.logits.shape, expected_shape)
expected_doc_scores = torch.tensor([[75.0286, 74.4998, 74.0804, 74.0306, 73.9504]]).to(torch_device)
_assert_tensors_equal(expected_doc_scores, output.doc_scores, atol=TOLERANCE)
expected_loss = torch.tensor([38.7446]).to(torch_device)
_assert_tensors_equal(expected_loss, output.loss, atol=TOLERANCE)
@slow
def test_rag_token_inference(self):
rag_config = self.get_rag_config()
rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
"facebook/dpr-question_encoder-single-nq-base"
)
rag_retriever = RagRetriever(
rag_config,
question_encoder_tokenizer=rag_question_encoder_tokenizer,
generator_tokenizer=rag_decoder_tokenizer,
)
rag_token = self.token_model
rag_token.set_retriever(rag_retriever)
input_ids = rag_question_encoder_tokenizer(
"who sings does he love me with reba", return_tensors="pt"
).input_ids
decoder_input_ids = rag_decoder_tokenizer("Linda Davis", return_tensors="pt").input_ids
input_ids = input_ids.to(torch_device)
decoder_input_ids = decoder_input_ids.to(torch_device)
with torch.no_grad():
output = rag_token(
input_ids,
labels=decoder_input_ids,
)
expected_shape = torch.Size([5, 5, 50264])
self.assertEqual(output.logits.shape, expected_shape)
expected_doc_scores = torch.tensor([[75.0286, 74.4998, 74.0804, 74.0306, 73.9504]]).to(torch_device)
_assert_tensors_equal(expected_doc_scores, output.doc_scores, atol=TOLERANCE)
expected_loss = torch.tensor([38.7045]).to(torch_device)
_assert_tensors_equal(expected_loss, output.loss, atol=TOLERANCE)
@slow
def test_rag_token_generate_beam(self):
rag_config = self.get_rag_config()
rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
"facebook/dpr-question_encoder-single-nq-base"
)
rag_retriever = RagRetriever(
rag_config,
question_encoder_tokenizer=rag_question_encoder_tokenizer,
generator_tokenizer=rag_decoder_tokenizer,
)
rag_token = self.token_model
rag_token.set_retriever(rag_retriever)
input_ids = rag_question_encoder_tokenizer(
"who sings does he love me with reba", return_tensors="pt"
).input_ids
input_ids = input_ids.to(torch_device)
output_ids = rag_token.generate(
input_ids,
decoder_start_token_id=rag_token.generator.config.decoder_start_token_id,
num_beams=2,
num_return_sequences=2,
)
# sequence generate test
output_text_1 = rag_decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True)
output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)
# Expected outputs as given by model at integration time.
EXPECTED_OUTPUT_TEXT_1 = "The songwriting credits are credited to ABBA"
EXPECTED_OUTPUT_TEXT_2 = 'The songwriting credits are credited to "B'
self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)
@slow
def test_rag_token_generate_batch(self):
rag_config = self.get_rag_config()
rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
"facebook/dpr-question_encoder-single-nq-base"
)
rag_retriever = RagRetriever(
rag_config,
question_encoder_tokenizer=rag_question_encoder_tokenizer,
generator_tokenizer=rag_decoder_tokenizer,
)
rag_token = self.token_model
rag_token.set_retriever(rag_retriever)
questions = [
"who sings does he love me with reba",
"how many pages is invisible man by ralph ellison",
]
input_ids = rag_question_encoder_tokenizer.batch_encode_plus(
questions,
return_tensors="pt",
padding=True,
truncation=True,
).input_ids
input_ids = input_ids.to(torch_device)
output_ids = rag_token.generate(
input_ids,
decoder_start_token_id=rag_token.generator.config.decoder_start_token_id,
num_beams=4,
num_return_sequences=1,
max_length=10,
)
# sequence generate test
output_text_1 = rag_decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True)
output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)
# Expected outputs as given by model at integration time.
EXPECTED_OUTPUT_TEXT_1 = '"People Need Love" is the'
EXPECTED_OUTPUT_TEXT_2 = '"How many pages is invisible man'
self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)
@slow
def test_rag_sequence_generate_batch(self):
# IMPORTAN: This test fails on GPU, but is fine on CPU -> beam search is very sensible
rag_config = self.get_rag_config()
rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
"facebook/dpr-question_encoder-single-nq-base"
)
rag_retriever = RagRetriever(
rag_config,
question_encoder_tokenizer=rag_question_encoder_tokenizer,
generator_tokenizer=rag_decoder_tokenizer,
)
rag_sequence = self.sequence_model
rag_sequence.set_retriever(rag_retriever)
questions = [
"who sings does he love me with reba",
"how many pages is invisible man by ralph ellison",
]
input_ids = rag_question_encoder_tokenizer.batch_encode_plus(
questions,
return_tensors="pt",
padding=True,
truncation=True,
).input_ids
input_ids = input_ids.to(torch_device)
output_ids = rag_sequence.generate(
input_ids,
decoder_start_token_id=rag_sequence.generator.config.decoder_start_token_id,
num_beams=4,
num_return_sequences=1,
max_length=10,
)
# sequence generate test
output_text_1 = rag_decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True)
output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)
# Expected outputs as given by model at integration time.
EXPECTED_OUTPUT_TEXT_1 = '"I Know Him So Well"'
EXPECTED_OUTPUT_TEXT_2 = '"Howl" chronicles the'
self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)
@slow
def test_rag_sequence_generate_beam(self):
rag_config = self.get_rag_config()
rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
"facebook/dpr-question_encoder-single-nq-base"
)
rag_retriever = RagRetriever(
rag_config,
question_encoder_tokenizer=rag_question_encoder_tokenizer,
generator_tokenizer=rag_decoder_tokenizer,
)
rag_token = self.sequence_model
rag_token.set_retriever(rag_retriever)
input_ids = rag_question_encoder_tokenizer(
"who sings does he love me with reba", return_tensors="pt"
).input_ids
input_ids = input_ids.to(torch_device)
output_ids = rag_token.generate(
input_ids,
decoder_start_token_id=rag_token.generator.config.decoder_start_token_id,
num_beams=2,
num_return_sequences=2,
)
# sequence generate test
output_text_1 = rag_decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True)
output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)
# Expected outputs as given by model at integration time.
EXPECTED_OUTPUT_TEXT_1 = """ ABBA / small label like Playboy Records did not have the distribution resources to meet the demand for the single from retailers and radio programmers. The foursome decided to record their first album together in late 1972, and sessions began on 26 September 1972. The women shared lead vocals on "Nina, Pretty Ballerina" that day."""
EXPECTED_OUTPUT_TEXT_2 = """ ABBA / small label like Playboy Records did not have the distribution resources to meet the demand for the single from retailers and radio programmers. The foursome decided to record their first album together in late 1972, and sessions began on 26 September 1972. The women shared lead vocals on "Nina, Pretty Ballerina" (a top ten hit in Austria)"""
self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)
@require_torch
@require_retrieval
class RagModelSaveLoadTests(unittest.TestCase):
def get_rag_config(self):
question_encoder_config = AutoConfig.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
generator_config = AutoConfig.from_pretrained("facebook/bart-large-cnn")
return RagConfig.from_question_encoder_generator_configs(
question_encoder_config,
generator_config,
bos_token_id=0,
decoder_start_token_id=2,
eos_token_id=2,
is_encoder_decoder=True,
pad_token_id=1,
vocab_size=50264,
title_sep=" / ",
doc_sep=" // ",
n_docs=5,
max_combined_length=300,
dataset="wiki_dpr",
dataset_split="train",
index_name="exact",
index_path=None,
use_dummy_dataset=True,
retrieval_vector_size=768,
retrieval_batch_size=8,
)
@slow
def test_rag_sequence_from_pretrained(self):
rag_config = self.get_rag_config()
rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
"facebook/dpr-question_encoder-single-nq-base"
)
rag_retriever = RagRetriever(
rag_config,
question_encoder_tokenizer=rag_question_encoder_tokenizer,
generator_tokenizer=rag_decoder_tokenizer,
)
input_ids = rag_question_encoder_tokenizer(
"who sings does he love me with reba", return_tensors="pt"
).input_ids
decoder_input_ids = rag_decoder_tokenizer("Linda Davis", return_tensors="pt").input_ids
input_ids = input_ids.to(torch_device)
decoder_input_ids = decoder_input_ids.to(torch_device)
with tempfile.TemporaryDirectory() as tmp_dirname:
rag_sequence = RagSequenceForGeneration.from_pretrained_question_encoder_generator(
"facebook/dpr-question_encoder-single-nq-base",
"facebook/bart-large-cnn",
retriever=rag_retriever,
config=rag_config,
).to(torch_device)
# check that the from pretrained methods work
rag_sequence.save_pretrained(tmp_dirname)
rag_sequence.from_pretrained(tmp_dirname, retriever=rag_retriever)
rag_sequence.to(torch_device)
with torch.no_grad():
output = rag_sequence(
input_ids,
labels=decoder_input_ids,
)
loss_pretrained = output.loss
del rag_sequence
question_encoder = AutoModel.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
generator = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
rag_sequence = RagSequenceForGeneration(
config=rag_config, question_encoder=question_encoder, generator=generator, retriever=rag_retriever
)
rag_sequence.to(torch_device)
with torch.no_grad():
output = rag_sequence(
input_ids,
labels=decoder_input_ids,
)
loss_init = output.loss
self.assertAlmostEqual(loss_pretrained.item(), loss_init.item(), places=4)
@slow
def test_rag_token_from_pretrained(self):
rag_config = self.get_rag_config()
rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
"facebook/dpr-question_encoder-single-nq-base"
)
rag_retriever = RagRetriever(
rag_config,
question_encoder_tokenizer=rag_question_encoder_tokenizer,
generator_tokenizer=rag_decoder_tokenizer,
)
input_ids = rag_question_encoder_tokenizer(
"who sings does he love me with reba", return_tensors="pt"
).input_ids
decoder_input_ids = rag_decoder_tokenizer("Linda Davis", return_tensors="pt").input_ids
input_ids = input_ids.to(torch_device)
decoder_input_ids = decoder_input_ids.to(torch_device)
with tempfile.TemporaryDirectory() as tmp_dirname:
rag_token = RagTokenForGeneration.from_pretrained_question_encoder_generator(
"facebook/dpr-question_encoder-single-nq-base",
"facebook/bart-large-cnn",
retriever=rag_retriever,
config=rag_config,
).to(torch_device)
# check that the from pretrained methods work
rag_token.save_pretrained(tmp_dirname)
rag_token.from_pretrained(tmp_dirname, retriever=rag_retriever)
rag_token.to(torch_device)
with torch.no_grad():
output = rag_token(
input_ids,
labels=decoder_input_ids,
)
loss_pretrained = output.loss
del rag_token
question_encoder = AutoModel.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
generator = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
rag_token = RagTokenForGeneration(
config=rag_config, question_encoder=question_encoder, generator=generator, retriever=rag_retriever
)
rag_token.to(torch_device)
with torch.no_grad():
output = rag_token(
input_ids,
labels=decoder_input_ids,
)
loss_init = output.loss
self.assertAlmostEqual(loss_pretrained.item(), loss_init.item(), places=4)

Просмотреть файл

@ -35,28 +35,53 @@ if is_torch_available():
class T5ModelTester:
def __init__(self, parent):
def __init__(
self,
parent,
vocab_size=99,
n_positions=14,
batch_size=13,
encoder_seq_length=7,
decoder_seq_length=9,
# For common tests
seq_length=7,
is_training=True,
use_attention_mask=True,
use_labels=True,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
d_ff=37,
relative_attention_num_buckets=8,
dropout_rate=0.1,
initializer_factor=0.002,
eos_token_id=1,
pad_token_id=0,
decoder_start_token_id=0,
scope=None,
):
self.parent = parent
self.batch_size = 13
self.encoder_seq_length = 7
self.decoder_seq_length = 9
self.batch_size = batch_size
self.encoder_seq_length = encoder_seq_length
self.decoder_seq_length = decoder_seq_length
# For common tests
self.seq_length = self.decoder_seq_length
self.is_training = True
self.use_attention_mask = True
self.use_labels = True
self.vocab_size = 99
self.n_positions = 14
self.hidden_size = 32
self.num_hidden_layers = 5
self.num_attention_heads = 4
self.d_ff = 37
self.relative_attention_num_buckets = 8
self.dropout_rate = 0.1
self.initializer_factor = 0.002
self.eos_token_id = 1
self.pad_token_id = 0
self.decoder_start_token_id = 0
self.is_training = is_training
self.use_attention_mask = use_attention_mask
self.use_labels = use_labels
self.vocab_size = vocab_size
self.n_positions = n_positions
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.d_ff = d_ff
self.relative_attention_num_buckets = relative_attention_num_buckets
self.dropout_rate = dropout_rate
self.initializer_factor = initializer_factor
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
self.decoder_start_token_id = decoder_start_token_id
self.scope = None
def prepare_config_and_inputs(self):

223
tests/test_retrieval_rag.py Normal file
Просмотреть файл

@ -0,0 +1,223 @@
import json
import os
import pickle
import shutil
import tempfile
from unittest import TestCase
from unittest.mock import patch
import numpy as np
from datasets import Dataset
import faiss
from transformers.configuration_bart import BartConfig
from transformers.configuration_dpr import DPRConfig
from transformers.configuration_rag import RagConfig
from transformers.retrieval_rag import RagRetriever
from transformers.testing_utils import require_datasets, require_faiss, require_torch
from transformers.tokenization_bart import BartTokenizer
from transformers.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
from transformers.tokenization_dpr import DPRQuestionEncoderTokenizer
from transformers.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
@require_faiss
@require_datasets
class RagRetrieverTest(TestCase):
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
self.retrieval_vector_size = 8
# DPR tok
vocab_tokens = [
"[UNK]",
"[CLS]",
"[SEP]",
"[PAD]",
"[MASK]",
"want",
"##want",
"##ed",
"wa",
"un",
"runn",
"##ing",
",",
"low",
"lowest",
]
dpr_tokenizer_path = os.path.join(self.tmpdirname, "dpr_tokenizer")
os.makedirs(dpr_tokenizer_path, exist_ok=True)
self.vocab_file = os.path.join(dpr_tokenizer_path, DPR_VOCAB_FILES_NAMES["vocab_file"])
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
# BART tok
vocab = [
"l",
"o",
"w",
"e",
"r",
"s",
"t",
"i",
"d",
"n",
"\u0120",
"\u0120l",
"\u0120n",
"\u0120lo",
"\u0120low",
"er",
"\u0120lowest",
"\u0120newer",
"\u0120wider",
"<unk>",
]
vocab_tokens = dict(zip(vocab, range(len(vocab))))
merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
self.special_tokens_map = {"unk_token": "<unk>"}
bart_tokenizer_path = os.path.join(self.tmpdirname, "bart_tokenizer")
os.makedirs(bart_tokenizer_path, exist_ok=True)
self.vocab_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["vocab_file"])
self.merges_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["merges_file"])
with open(self.vocab_file, "w", encoding="utf-8") as fp:
fp.write(json.dumps(vocab_tokens) + "\n")
with open(self.merges_file, "w", encoding="utf-8") as fp:
fp.write("\n".join(merges))
def get_dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer:
return DPRQuestionEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer"))
def get_bart_tokenizer(self) -> BartTokenizer:
return BartTokenizer.from_pretrained(os.path.join(self.tmpdirname, "bart_tokenizer"))
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def get_dummy_hf_index_retriever(self):
dataset = Dataset.from_dict(
{
"id": ["0", "1"],
"text": ["foo", "bar"],
"title": ["Foo", "Bar"],
"embeddings": [np.ones(self.retrieval_vector_size), 2 * np.ones(self.retrieval_vector_size)],
}
)
dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)
config = RagConfig(
retrieval_vector_size=self.retrieval_vector_size,
question_encoder=DPRConfig().to_dict(),
generator=BartConfig().to_dict(),
)
with patch("transformers.retrieval_rag.load_dataset") as mock_load_dataset:
mock_load_dataset.return_value = dataset
retriever = RagRetriever(
config,
question_encoder_tokenizer=self.get_dpr_tokenizer(),
generator_tokenizer=self.get_bart_tokenizer(),
)
return retriever
def get_dummy_legacy_index_retriever(self):
dataset = Dataset.from_dict(
{
"id": ["0", "1"],
"text": ["foo", "bar"],
"title": ["Foo", "Bar"],
"embeddings": [np.ones(self.retrieval_vector_size + 1), 2 * np.ones(self.retrieval_vector_size + 1)],
}
)
dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)
index_file_name = os.path.join(self.tmpdirname, "hf_bert_base.hnswSQ8_correct_phi_128.c_index")
dataset.save_faiss_index("embeddings", index_file_name + ".index.dpr")
pickle.dump(dataset["id"], open(index_file_name + ".index_meta.dpr", "wb"))
passages_file_name = os.path.join(self.tmpdirname, "psgs_w100.tsv.pkl")
passages = {sample["id"]: [sample["text"], sample["title"]] for sample in dataset}
pickle.dump(passages, open(passages_file_name, "wb"))
config = RagConfig(
retrieval_vector_size=self.retrieval_vector_size,
question_encoder=DPRConfig().to_dict(),
generator=BartConfig().to_dict(),
index_name="legacy",
index_path=self.tmpdirname,
passages_path=self.tmpdirname,
)
retriever = RagRetriever(
config, question_encoder_tokenizer=self.get_dpr_tokenizer(), generator_tokenizer=self.get_bart_tokenizer()
)
return retriever
def test_hf_index_retriever_retrieve(self):
n_docs = 1
retriever = self.get_dummy_hf_index_retriever()
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
self.assertEqual(len(doc_dicts), 2)
self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
self.assertListEqual(list(doc_ids), [1, 0])
def test_legacy_index_retriever_retrieve(self):
n_docs = 1
retriever = self.get_dummy_legacy_index_retriever()
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
self.assertEqual(len(doc_dicts), 2)
self.assertEqual(sorted(doc_dicts[0]), ["text", "title"])
self.assertEqual(len(doc_dicts[0]["text"]), n_docs)
self.assertEqual(doc_dicts[0]["text"][0], "bar") # max inner product is reached with second doc
self.assertEqual(doc_dicts[1]["text"][0], "foo") # max inner product is reached with first doc
self.assertListEqual(list(doc_ids), [1, 0])
@require_torch
def test_hf_index_retriever_call(self):
import torch
n_docs = 1
retriever = self.get_dummy_hf_index_retriever()
question_input_ids = [[5, 7], [10, 11]]
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
out = retriever(question_input_ids, hidden_states, prefix=retriever.config.generator.prefix, n_docs=n_docs)
context_input_ids, context_attention_mask, retrieved_doc_embeds = (
out["context_input_ids"],
out["context_attention_mask"],
out["retrieved_doc_embeds"],
)
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
self.assertIsInstance(context_input_ids, list)
self.assertIsInstance(context_attention_mask, list)
self.assertIsInstance(retrieved_doc_embeds, np.ndarray)
out = retriever(
question_input_ids,
hidden_states,
prefix=retriever.config.generator.prefix,
n_docs=n_docs,
return_tensors="pt",
)
context_input_ids, context_attention_mask, retrieved_doc_embeds, doc_ids = ( # noqa: F841
out["context_input_ids"],
out["context_attention_mask"],
out["retrieved_doc_embeds"],
out["doc_ids"],
)
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
self.assertIsInstance(context_input_ids, torch.Tensor)
self.assertIsInstance(context_attention_mask, torch.Tensor)
self.assertIsInstance(retrieved_doc_embeds, torch.Tensor)

Просмотреть файл

@ -0,0 +1,110 @@
import json
import os
import shutil
import tempfile
from unittest import TestCase
from transformers.configuration_bart import BartConfig
from transformers.configuration_dpr import DPRConfig
from transformers.file_utils import is_datasets_available, is_faiss_available, is_torch_available
from transformers.testing_utils import require_datasets, require_faiss, require_torch
from transformers.tokenization_bart import BartTokenizer
from transformers.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
from transformers.tokenization_dpr import DPRQuestionEncoderTokenizer
from transformers.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
if is_torch_available() and is_datasets_available() and is_faiss_available():
from transformers.configuration_rag import RagConfig
from transformers.tokenization_rag import RagTokenizer
@require_faiss
@require_datasets
@require_torch
class RagTokenizerTest(TestCase):
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
self.retrieval_vector_size = 8
# DPR tok
vocab_tokens = [
"[UNK]",
"[CLS]",
"[SEP]",
"[PAD]",
"[MASK]",
"want",
"##want",
"##ed",
"wa",
"un",
"runn",
"##ing",
",",
"low",
"lowest",
]
dpr_tokenizer_path = os.path.join(self.tmpdirname, "dpr_tokenizer")
os.makedirs(dpr_tokenizer_path, exist_ok=True)
self.vocab_file = os.path.join(dpr_tokenizer_path, DPR_VOCAB_FILES_NAMES["vocab_file"])
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
# BART tok
vocab = [
"l",
"o",
"w",
"e",
"r",
"s",
"t",
"i",
"d",
"n",
"\u0120",
"\u0120l",
"\u0120n",
"\u0120lo",
"\u0120low",
"er",
"\u0120lowest",
"\u0120newer",
"\u0120wider",
"<unk>",
]
vocab_tokens = dict(zip(vocab, range(len(vocab))))
merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
self.special_tokens_map = {"unk_token": "<unk>"}
bart_tokenizer_path = os.path.join(self.tmpdirname, "bart_tokenizer")
os.makedirs(bart_tokenizer_path, exist_ok=True)
self.vocab_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["vocab_file"])
self.merges_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["merges_file"])
with open(self.vocab_file, "w", encoding="utf-8") as fp:
fp.write(json.dumps(vocab_tokens) + "\n")
with open(self.merges_file, "w", encoding="utf-8") as fp:
fp.write("\n".join(merges))
def get_dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer:
return DPRQuestionEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer"))
def get_bart_tokenizer(self) -> BartTokenizer:
return BartTokenizer.from_pretrained(os.path.join(self.tmpdirname, "bart_tokenizer"))
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def test_save_load_pretrained_with_saved_config(self):
save_dir = os.path.join(self.tmpdirname, "rag_tokenizer")
rag_config = RagConfig(question_encoder=DPRConfig().to_dict(), generator=BartConfig().to_dict())
rag_tokenizer = RagTokenizer(question_encoder=self.get_dpr_tokenizer(), generator=self.get_bart_tokenizer())
rag_config.save_pretrained(save_dir)
rag_tokenizer.save_pretrained(save_dir)
new_rag_tokenizer = RagTokenizer.from_pretrained(save_dir, config=rag_config)
self.assertIsInstance(new_rag_tokenizer.question_encoder, DPRQuestionEncoderTokenizer)
self.assertEqual(new_rag_tokenizer.question_encoder.vocab, rag_tokenizer.question_encoder.vocab)
self.assertIsInstance(new_rag_tokenizer.generator, BartTokenizer)
self.assertEqual(new_rag_tokenizer.generator.encoder, rag_tokenizer.generator.encoder)