Reading Notes: Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks
The original paper is linked here. This reading note focuses specifically on the model architecture.
LLMs’ ability to access and precisely manipulate knowledge is still limited, and hence on knowledge-intensive tasks their performance lags behind task-specific architectures. This paper explores a general-purpose fine-tuning recipe for retrieval-augmented generation (RAG): models which combine pre-trained parametric and non-parametric memory for language generation. The parametric memory of RAG models is a pre-trained seq2seq model, and the non-parametric memory is a dense vector index of Wikipedia, accessed with a pre-trained neural retriever. We combine these components in a probabilistic model trained end-to-end. The retriever (Dense Passage Retriever, henceforth DPR) provides latent documents conditioned on the input, and the seq2seq model (BART) then conditions on these latent documents together with the input to generate the output. We marginalize the latent documents with a top-K approximation, either on a per-output basis (assuming the same document is responsible for all tokens) or a per-token basis (where different documents are responsible for different tokens).
Methods
We explore RAG models, which use the input sequence x to retrieve text documents z and use them as additional context when generating the target sequence y. The models leverage two components:
- a retriever p(z|x) with parameters \eta that returns a (top-K truncated) distribution over text passages given a query x
- a generator p(y_i|x, z, y_1,…, y_{i-1}) parametrized by \theta that generates the current token conditioned on the previous i-1 tokens y_1,…, y_{i-1}, the original input x, and a retrieved passage z.
Models
To train the retriever and generator end-to-end, we treat the retrieved document as a latent variable. We propose two models that marginalize over the latent documents in different ways to produce a distribution over generated text.
RAG-Sequence Model
The model uses the same retrieved document to generate the complete sequence. Concretely, the top K documents are retrieved using the retriever, and the generator produces the output sequence probability for each document; these per-document probabilities are then marginalized (weighted by p(z|x) and summed).
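In the paper's notation, with p_\eta the retriever and p_\theta the generator, this marginal is

p_{RAG-Sequence}(y|x) ≈ \sum_{z ∈ top-k(p(·|x))} p_\eta(z|x) p_\theta(y|x, z) = \sum_{z ∈ top-k(p(·|x))} p_\eta(z|x) \prod_i p_\theta(y_i|x, z, y_1,…,y_{i-1})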
RAG-Token Model
The model uses a different latent document for each target token and marginalizes accordingly. This allows the generator to choose content from several documents when producing an answer. Concretely, the top K documents are retrieved using the retriever, and then the generator produces a distribution for the next output token for each document before marginalizing, repeating the process for the following output token.
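The corresponding marginal, which moves the sum over documents inside the product over tokens, is

p_{RAG-Token}(y|x) ≈ \prod_i \sum_{z ∈ top-k(p(·|x))} p_\eta(z|x) p_\theta(y_i|x, z, y_1,…,y_{i-1})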
Retriever: Dense Passage Retriever (DPR)
The retrieval component p(z|x) is based on DPR. DPR follows a bi-encoder architecture:

p_\eta(z|x) ∝ exp( d(z)^T q(x) ),   d(z) = BERT_d(z),   q(x) = BERT_q(x),

where d(z) is a dense representation of a document produced by a BERT document encoder, and q(x) is a query representation produced by a query encoder, also based on BERT. The basic idea is simple: the more relevant a document is to a query, the larger the inner product between their dense representations, and hence the higher the retrieval probability p(z|x).
Calculating top-k(p(·|x)), the list of the k documents z with the highest prior probability p(z|x), is a Maximum Inner Product Search (MIPS) problem, which can be approximately solved in sub-linear time. We use a pre-trained bi-encoder from DPR to initialize our retriever and to build the document index.
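As a minimal sketch of this retrieval step, the snippet below assumes the document embeddings d(z) and the query embedding q(x) have already been computed by the two BERT encoders, and uses a FAISS inner-product index (an exact flat index here for simplicity; the paper's index uses an approximate search structure). The array names and sizes are illustrative.

```python
import numpy as np
import faiss  # FAISS provides the MIPS index; an exact flat index is used here for simplicity

dim = 768  # dimensionality of the BERT representations

# Placeholders standing in for the pre-computed encoder outputs:
# doc_vecs[i] plays the role of d(z_i), query_vec plays the role of q(x).
doc_vecs = np.random.rand(10_000, dim).astype("float32")
query_vec = np.random.rand(1, dim).astype("float32")

index = faiss.IndexFlatIP(dim)  # inner-product index, so search scores are q(x)^T d(z)
index.add(doc_vecs)

k = 5
scores, doc_ids = index.search(query_vec, k)  # top-k documents by inner product

# Renormalizing the top-k scores with a softmax gives the truncated retrieval
# distribution p(z|x) that the generator marginalizes over.
exp_scores = np.exp(scores - scores.max())
p_z_given_x = exp_scores / exp_scores.sum()
```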
Generator: BART
The generator component p(y_i|x, z, y_1,…, y_{i-1}) can be modeled using any encoder-decoder. We use BART-large, a pre-trained seq2seq transformer with 400M parameters. To combine the input x with the retrieved content z when generating from BART, we simply concatenate them.
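A minimal sketch of this generation step with the Hugging Face transformers BART interface, assuming a single retrieved passage; the separator between passage and query is an illustrative choice, not necessarily the exact format used in the paper.

```python
from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")

query = "Who composed The Magic Flute?"  # the input x
passage = "The Magic Flute is an opera by Wolfgang Amadeus Mozart, premiered in 1791."  # a retrieved z

# Concatenate the retrieved passage with the query to form the encoder input.
# The " // " separator is an illustrative choice, not the paper's exact format.
encoder_input = passage + " // " + query
inputs = tokenizer(encoder_input, return_tensors="pt", truncation=True, max_length=1024)

output_ids = model.generate(**inputs, num_beams=4, max_length=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```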
Training
We jointly train the retriever and generator components without any direct supervision on what document should be retrieved. Given a fine-tuning training corpus of input/output pairs (x_i, y_i), we minimize the negative marginal log-likelihood of each target using stochastic gradient descent with Adam.
Updating the document encoder BERT during training is costly, as it requires the document index to be periodically rebuilt. We do not find this step necessary for strong performance, and keep the document encoder (and index) fixed, only fine-tuning the query encoder BERT and the BART generator.
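A minimal PyTorch sketch of the marginal negative log-likelihood for the RAG-Token case, assuming the retriever's top-K document scores and the generator's per-document token log-probabilities of the target have already been computed; the tensor names and shapes are illustrative.

```python
import torch
import torch.nn.functional as F

def rag_token_nll(doc_scores: torch.Tensor, token_logprobs: torch.Tensor) -> torch.Tensor:
    """Marginal negative log-likelihood for the RAG-Token objective.

    doc_scores:     (batch, K)    retrieval scores q(x)^T d(z) for the top-K documents
    token_logprobs: (batch, K, T) log p(y_i | x, z_k, y_{<i}) for every target token and document
    """
    # log p(z|x), renormalized over the top-K retrieved documents
    doc_logprobs = F.log_softmax(doc_scores, dim=-1).unsqueeze(-1)    # (batch, K, 1)
    # per-token marginal: log sum_z p(z|x) p(y_i | x, z, y_{<i})
    marginal = torch.logsumexp(doc_logprobs + token_logprobs, dim=1)  # (batch, T)
    # sum over tokens, average over the batch, negate to get a loss
    return -marginal.sum(dim=-1).mean()

# Gradients reach the query encoder through doc_scores and the BART generator through
# token_logprobs; the document encoder and its index stay frozen.
```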
Decoding
RAG-Sequence: The likelihood p(y|x) does not break into a conventional per-token likelihood, hence we cannot solve it with a single beam search. Instead, we run beam search for each document z, scoring each hypothesis using p(y_i|x, z, y_1,…,y_{i-1}). This yields a set of hypotheses Y, some of which may not have appeared in the beams of all documents.
- Thorough Decoding: To estimate the probability of a hypothesis y, we run an additional forward pass for each document z for which y does not appear in the beam, multiply the generator probability by p(z|x), and then sum the probabilities across documents to obtain the marginal (see the sketch after this list).
- Fast Decoding: For longer output sequences, |Y| can become large, requiring many forward passes. For more efficient decoding, we can make a further approximation that p(y | x, z_i)≈0, where y was not generated during beam search from x, z_i. This avoids the need to run additional forward passes once the candidate set Y has been generated.
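The sketch below shows how Thorough Decoding might aggregate scores over the candidate set Y, assuming a hypothetical helper sequence_logprob(z, y) that runs one generator forward pass and returns log p(y|x, z); beam search itself and the Fast Decoding shortcut are omitted.

```python
import math

def thorough_decoding_score(candidates, docs, doc_logprobs, sequence_logprob):
    """Pick the candidate y with the highest marginal p(y|x).

    candidates:       hypotheses y collected from the per-document beams
    docs:             the top-K retrieved documents z
    doc_logprobs:     log p(z|x) for each document, in the same order as docs
    sequence_logprob: hypothetical callable (z, y) -> log p(y|x, z), i.e. one forward pass
    """
    best_y, best_score = None, -math.inf
    for y in candidates:
        # log p(y|x) = log sum_z p(z|x) p(y|x, z), computed with a log-sum-exp for stability
        terms = [lp_z + sequence_logprob(z, y) for z, lp_z in zip(docs, doc_logprobs)]
        m = max(terms)
        score = m + math.log(sum(math.exp(t - m) for t in terms))
        if score > best_score:
            best_y, best_score = y, score
    return best_y
```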
RAG-Token: To decode, we can simply plug the per-token transition probability p’(y_i|x, y_1,…,y_{i-1}) into a standard beam decoder.
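Here p’ is the per-token marginal already used by the RAG-Token model:

p’_\theta(y_i|x, y_1,…,y_{i-1}) = \sum_{z ∈ top-k(p(·|x))} p_\eta(z|x) p_\theta(y_i|x, z, y_1,…,y_{i-1})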