Ideal Generative AI vs. Reality
Foundation LLMs have ingested vast amounts of text, enabling their chatbot counterparts to hold intelligent conversations and perform specific tasks. This democratizes access to comprehensive information, eliminating the need to find the right keywords or pick specific sites to read. However, LLMs are prone to producing disjointed responses, often providing the statistically most probable answer or simply the one you want to hear (flattery), a tendency rooted in how these models are built and trained. Relying on an LLM’s internal knowledge alone does not reliably yield accurate information.
Chat LLMs are notorious for fabricating citations to scientific articles or court cases that don’t exist. For instance, lawyers representing a client suing an airline cited non-existent cases in a court filing. A 2023 study reported that when ChatGPT was asked to include citations, it provided real references only 14% of the time. Fabricated sources, rambling, and inaccuracies offered to satisfy the prompt are collectively termed hallucinations, a significant hurdle to overcome before AI can be fully adopted and trusted by the public.
One way to counter LLMs generating false sources or inaccuracies is Retrieval-Augmented Generation (RAG). RAG not only reduces the tendency of LLMs to hallucinate but also offers several other benefits.
These benefits include access to an up-to-date knowledge base, specialization (e.g., drawing on private data sources), supplying models with information beyond what is stored in their parametric memory (which allows for smaller models), and the ability to follow up with more data from legitimate references.
What is Retrieval-Augmented Generation (RAG)?
Retrieval-Augmented Generation (RAG) is a deep learning architecture, used with LLMs and other transformer networks, that retrieves relevant documents or other snippets and adds them to the context window, giving the LLM additional information to generate useful responses. A typical RAG system has two main modules: retrieval and generation.
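To make the two modules concrete, here is a minimal, self-contained sketch in plain Python. It is not how Lewis et al. or Hugging Face implement RAG: the corpus, the word-count "embeddings", and the generate() stub are hypothetical stand-ins chosen so the example runs without any model.

```python
# Minimal two-module RAG sketch: a retriever picks passages, and the
# augmented prompt is handed to a generator. Everything here is a toy:
# CORPUS, embed(), and generate() are hypothetical stand-ins.
from collections import Counter
from typing import List

CORPUS = [
    "Prometheus was a Great Basin bristlecone pine tree.",
    "FAISS is a library for efficient similarity search.",
    "BART is a sequence-to-sequence transformer.",
]

def embed(text: str) -> Counter:
    # Crude sparse "embedding": lowercase word counts, punctuation stripped.
    cleaned = text.lower().replace(".", "").replace("?", "")
    return Counter(cleaned.split())

def retrieve(query: str, k: int = 1) -> List[str]:
    # Retrieval module: rank passages by inner product with the query.
    q = embed(query)

    def score(doc: str) -> int:
        d = embed(doc)
        return sum(q[w] * d[w] for w in q)

    return sorted(CORPUS, key=score, reverse=True)[:k]

def build_prompt(query: str, passages: List[str]) -> str:
    # Augmentation: retrieved text is placed in the context window.
    context = "\n".join(passages)
    return f"Context:\n{context}\n\nQuestion: {query}\nAnswer:"

def generate(prompt: str) -> str:
    # Generation module (hypothetical stub): a real system calls an LLM
    # here; this stand-in just echoes the first context line.
    return prompt.splitlines()[1]

query = "What kind of tree was Prometheus?"
prompt = build_prompt(query, retrieve(query))
print(generate(prompt))
```

Swapping the toy retriever for dense embeddings and the stub for a real LLM keeps the same shape: retrieve, augment the prompt, generate.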
The primary reference for RAG is a paper by Lewis et al. from Facebook AI Research. In the paper, the authors use a pair of BERT-based encoders, one for queries and one for documents, to embed text as vectors. These embeddings are then used to identify the top-k (usually 5 or 10) documents via Maximum Inner Product Search (MIPS). As the name suggests, MIPS ranks documents by the inner (or dot) product between the query’s encoded vector and the document vectors stored in a precomputed index, which serves as the model’s non-parametric external memory.
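The scoring step can be sketched with NumPy as a brute-force MIPS. The random vectors below are assumptions standing in for the BERT-based embeddings, and the softmax over the top-k scores mirrors the paper’s retrieval distribution, p(z|x) ∝ exp(d(z)ᵀq(x)). At the scale of the paper’s Wikipedia index, the brute-force search would be replaced by an approximate FAISS index.

```python
# Brute-force Maximum Inner Product Search (MIPS) over a toy document index.
# Assumptions: random vectors stand in for BERT-based embeddings; the
# dimension (8) and corpus size (100) are arbitrary.
import numpy as np

rng = np.random.default_rng(0)
doc_embeddings = rng.normal(size=(100, 8))  # precomputed d(z) for each document
query_embedding = rng.normal(size=8)         # q(x) for the incoming query

def mips_top_k(q, docs, k=5):
    scores = docs @ q                  # inner product d(z)^T q(x) for every doc
    top = np.argsort(-scores)[:k]      # indices of the k highest scores
    return top, scores[top]

top_idx, top_scores = mips_top_k(query_embedding, doc_embeddings, k=5)

# Softmax over the top-k scores: p(z | x) proportional to exp(d(z)^T q(x)).
probs = np.exp(top_scores - top_scores.max())
probs /= probs.sum()
print(top_idx, probs.round(3))
```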
As described in the paper by Lewis et al., RAG was designed to make LLMs more effective in knowledge-intensive tasks that « cannot reasonably be expected to be completed by humans without access to an external knowledge source. » Think of taking an open-book versus a closed-book exam, and you’ll get a good idea of how RAG could complement LLM-based systems.
RAG with the Hugging Face 🤗 Library
Lewis et al. open-sourced their RAG models on the Hugging Face Hub, allowing us to experiment with the same models used in the paper. A fresh Python 3.8 virtual environment created with virtualenv is recommended.
virtualenv my_env --python=python3.8
source my_env/bin/activate
After activating the environment, we can install dependencies using pip: Hugging Face transformers and datasets, Facebook’s FAISS library that RAG uses for vector search, and PyTorch for use as a backend.
pip install transformers
pip install datasets
pip install faiss-cpu==1.8.0
# see https://pytorch.org/get-started/locally/ to match
# the PyTorch version to your system
pip install torch
Lewis et al. implemented two different versions of RAG: rag-sequence and rag-token. Rag-sequence uses the same retrieved document to augment the generation of an entire sequence, while rag-token can use different snippets for each token. Both versions use the same Hugging Face classes for tokenization and retrieval, and the API is largely the same, but each version has a unique class for generation. These classes are imported from the transformers library.
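The difference between the two variants comes down to where the model marginalizes over the retrieved documents. Following Lewis et al., RAG-Sequence computes p(y|x) ≈ Σ_z p(z|x) Π_i p(y_i | x, z, y_1:i-1), while RAG-Token computes p(y|x) ≈ Π_i Σ_z p(z|x) p(y_i | x, z, y_1:i-1). The toy NumPy example below plugs made-up probabilities (an assumption, not model output) into both.

```python
# Toy comparison of how RAG-Sequence and RAG-Token marginalize over documents.
# Assumptions: 2 retrieved documents, a 2-token output, made-up probabilities;
# real models get these from the retriever and the BART generator.
import numpy as np

p_doc = np.array([0.7, 0.3])   # p(z | x) for documents z1, z2
# p_tok[z, i]: probability of the i-th output token given document z and the prefix
p_tok = np.array([[0.9, 0.2],
                  [0.1, 0.8]])

# RAG-Sequence: one document per whole sequence, then sum over documents.
p_sequence = np.sum(p_doc * np.prod(p_tok, axis=1))

# RAG-Token: sum over documents at every token, then multiply across tokens.
p_token = np.prod(p_doc @ p_tok)

print(f"RAG-Sequence p(y|x) = {p_sequence:.3f}")
print(f"RAG-Token    p(y|x) = {p_token:.3f}")
```

Because the first token is best supported by z1 and the second by z2, RAG-Token, which can switch documents per token, gives this output a higher probability (about 0.251 versus 0.150).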
from transformers import RagTokenizer, RagRetriever
from transformers import RagTokenForGeneration
from transformers import RagSequenceForGeneration
The first time a RagRetriever is instantiated with the default « wiki_dpr » dataset, it will trigger a substantial download (around 300 GB). If you have a large data drive and want Hugging Face to use it (instead of the default cache folder in your home directory), you can set the environment variable HF_DATASETS_CACHE.
# in the shell:
export HF_DATASETS_CACHE="/path/to/data/drive"
# add the line above to ~/.bashrc to set the variable in every new shell
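If you prefer to configure the cache from Python rather than the shell, a sketch like the following should work, assuming the variable is set before the datasets library is imported (datasets reads HF_DATASETS_CACHE when it is first imported). The path is a placeholder.

```python
# Python alternative to exporting HF_DATASETS_CACHE in the shell.
# "/path/to/data/drive" is a placeholder for your own large drive.
import os

os.environ["HF_DATASETS_CACHE"] = "/path/to/data/drive"

# Import Hugging Face libraries only after the variable is set, e.g.:
# from transformers import RagRetriever
print(os.environ["HF_DATASETS_CACHE"])
```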
It is worth confirming that the code works before downloading the full wiki_dpr dataset. To postpone the large download until you’re ready, pass use_dummy_dataset=True when instantiating the retriever. You will also instantiate a tokenizer to convert strings into integer indices (corresponding to tokens in a vocabulary) and back. Both the sequence and token versions of RAG use the same tokenizer. Rag-sequence and rag-token each come in fine-tuned versions (e.g., rag-token-nq, fine-tuned on Natural Questions) and base versions (e.g., rag-token-base).
# one tokenizer serves both rag-token and rag-sequence
tokenizer = RagTokenizer.from_pretrained(
    "facebook/rag-token-nq")
# retrievers over the full wiki_dpr index (large download)
token_retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-nq",
    index_name="compressed",
    use_dummy_dataset=False)
seq_retriever = RagRetriever.from_pretrained(
    "facebook/rag-sequence-nq",
    index_name="compressed",
    use_dummy_dataset=False)
# retriever over the small dummy dataset, handy for testing
dummy_retriever = RagRetriever.from_pretrained(
    "facebook/rag-sequence-nq",
    index_name="exact",
    use_dummy_dataset=True)
# each RAG variant has its own generation class
token_model = RagTokenForGeneration.from_pretrained(
    "facebook/rag-token-nq",
    retriever=token_retriever)
seq_model = RagSequenceForGeneration.from_pretrained(
    "facebook/rag-sequence-nq",
    retriever=seq_retriever)
dummy_model = RagSequenceForGeneration.from_pretrained(
    "facebook/rag-sequence-nq",
    retriever=dummy_retriever)
Once your models are instantiated, you can provide a query, tokenize it, and pass it to each model’s « generate » method. We will compare the answers from rag-sequence, rag-token, and a rag-sequence model whose retriever uses the dummy version of the wiki_dpr dataset. Note that these RAG models are case-insensitive.
query = "what is the name of the oldest tree on Earth?"
# the RAG tokenizer encodes the question with its question-encoder tokenizer
inputs = tokenizer(query, return_tensors="pt")
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]

token_generated = token_model.generate(
    input_ids=input_ids, attention_mask=attention_mask)
token_decoded = tokenizer.batch_decode(
    token_generated, skip_special_tokens=True)
seq_generated = seq_model.generate(
    input_ids=input_ids, attention_mask=attention_mask)
seq_decoded = tokenizer.batch_decode(
    seq_generated, skip_special_tokens=True)
dummy_generated = dummy_model.generate(
    input_ids=input_ids, attention_mask=attention_mask)
dummy_decoded = tokenizer.batch_decode(
    dummy_generated, skip_special_tokens=True)
print(f"answers to query '{query}': ")
print(f"\t rag-sequence-nq: {seq_decoded[0]},"
      f" rag-token-nq: {token_decoded[0]},"
      f" rag (dummy): {dummy_decoded[0]}")
>> answers to query 'what is the name of the oldest tree on Earth?':
>> rag-sequence-nq: prometheus, rag-token-nq: prometheus, rag (dummy): 4862
For reference, the relevant Wikipedia passage notes that Prometheus was the oldest tree discovered until 2012, with its innermost, extant rings exceeding 4862 years of age.
In general, rag-token is correct more often than rag-sequence (though both are often correct), and rag-sequence is more often correct than RAG with the dummy retriever, which is unsurprising because the dummy dataset holds only a small subset of the wiki_dpr passages.
« What kind of context does the retriever provide? » you might wonder. To find out, we can deconstruct the generation process. Using seq_retriever and seq_model instantiated as above, we ask, « What is the name of the oldest tree on Earth? »
import torch

query = "what is the name of the oldest tree on Earth?"
inputs = tokenizer(query, return_tensors="pt")
input_ids = inputs["input_ids"]
# 1. encode the question with the DPR question encoder
question_hidden_states = seq_model.question_encoder(input_ids)[0]
# 2. retrieve passages (token ids, attention masks, and embeddings)
docs_dict = seq_retriever(input_ids.numpy(),
                          question_hidden_states.detach().numpy(),
                          return_tensors="pt")
# 3. score each retrieved passage by its inner product with the question
doc_scores = torch.bmm(
    question_hidden_states.unsqueeze(1),
    docs_dict["retrieved_doc_embeds"]
    .float().transpose(1, 2)).squeeze(1)
# 4. generate an answer conditioned on the retrieved contexts
generated = seq_model.generate(
    context_input_ids=docs_dict["context_input_ids"],
    context_attention_mask=docs_dict["context_attention_mask"],
    doc_scores=doc_scores)
generated_string = tokenizer.batch_decode(
    generated,
    skip_special_tokens=True)
# decode the contexts ("title / passage // question") fed to the generator
contexts = tokenizer.batch_decode(
    docs_dict["context_input_ids"],
    skip_special_tokens=True)
# single question in the batch, so the flat argmax is the document index
best_context = contexts[doc_scores.argmax().item()]
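The doc_scores line packs several tensor reshapes into one expression. The NumPy sketch below mirrors it step by step with random arrays (assumed shapes: one question, 5 retrieved documents, 768-dimensional embeddings) so you can check the shapes without downloading a model; it also confirms that each score is simply the question-document dot product used by MIPS.

```python
# NumPy mirror of the torch.bmm doc_scores computation, using random arrays.
# Assumed shapes: batch of 1 question, 5 retrieved docs, 768-dim embeddings.
import numpy as np

rng = np.random.default_rng(0)
batch, n_docs, dim = 1, 5, 768
question_hidden_states = rng.normal(size=(batch, dim))
retrieved_doc_embeds = rng.normal(size=(batch, n_docs, dim))

# unsqueeze(1): (batch, dim) -> (batch, 1, dim)
q = question_hidden_states[:, None, :]
# transpose(1, 2): (batch, n_docs, dim) -> (batch, dim, n_docs)
d = retrieved_doc_embeds.transpose(0, 2, 1)
# batched matmul, then squeeze(1): (batch, 1, n_docs) -> (batch, n_docs)
doc_scores = np.matmul(q, d).squeeze(1)

# Each entry is the dot product of the question with one document embedding.
expected = np.einsum("bd,bnd->bn", question_hidden_states, retrieved_doc_embeds)
print(doc_scores.shape, int(doc_scores.argmax()))
```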
We can print the best_context variable to see which passage the retriever scored highest.
print(f"Based on the retrieved context:"
      f"\n\n\t{best_context}\n")
Based on the retrieved context:
Prometheus (tree) / In a clonal organism, however, the individual clonal stems are not nearly so old, and no part of the organism is particularly old at any given time. Until 2012, Prometheus was thus the oldest "non-clonal" organism yet discovered, with its innermost, extant rings exceeding 4862 years of age. In the 1950s dendrochronologists were making active efforts to find the oldest living tree species in order to use the analysis of the rings for various research purposes, such as the evaluation of former climates, the dating of archaeological ruins, and addressing the basic scientific question of maximum potential lifespan. Bristlecone pines // what is the name of the oldest tree on earth?
print(f" rag-sequence-nq answers '{query}'"
f" with '{generated_string[0]}'")
We can also print the answer stored in the generated_string variable. Rag-sequence-nq answers « what is the name of the oldest tree on Earth? » with « Prometheus. »
What Can You Do with RAG?
In the past year and a half, there has been an explosion of LLMs and LLM tools. The BART-large generator used by Lewis et al. had only 400 million parameters, far smaller than today’s crop of LLMs, which typically start in the billion-parameter range even for « light » variants. Additionally, many models being trained, merged, and fine-tuned today are multimodal, combining text inputs and outputs with images or other tokenized data sources. Combining RAG with other tools can create complex capabilities, but the underlying models will not be immune to common LLM shortcomings. Flattery, hallucination, and unreliability all remain, and they risk being exacerbated as LLM use grows.
The most obvious applications of RAG are variations of conversational semantic search, but they might also incorporate multimodal inputs or generate images as part of the output. For example, an LLM using RAG over a project’s documentation lets you hold a conversation with that documentation. RAG could also be used to curate interactive notes for a literature review in a research project or thesis.
By integrating « chain of thought » reasoning, you can take a more agentic approach in which your models query the RAG system repeatedly and assemble more complex lines of inquiry or reasoning.
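The agentic idea can be sketched as a multi-hop retrieval loop: retrieve a passage, decide what to look up next, and repeat. Everything below is hypothetical: the passages, the title-matching lookup(), and propose_next_query(), which stands in for the LLM’s chain-of-thought step.

```python
# Hypothetical multi-hop retrieval loop. PASSAGES, lookup(), and
# propose_next_query() are toy stand-ins, not a real RAG or agent API.
from typing import List, Optional, Set, Tuple

PASSAGES = {
    "Prometheus": "Prometheus was a bristlecone pine growing in Nevada.",
    "Nevada": "Nevada is a state in the western United States.",
    "Bristlecone pine": "Bristlecone pines are among the longest-lived trees.",
}

def lookup(query: str) -> Optional[Tuple[str, str]]:
    # Toy retriever: return the first passage whose title appears in the query.
    for title, passage in PASSAGES.items():
        if title.lower() in query.lower():
            return title, passage
    return None

def propose_next_query(passage: str, seen: Set[str]) -> Optional[str]:
    # Stand-in for the LLM's reasoning step: follow an unseen title
    # mentioned in the passage just read.
    for title in PASSAGES:
        if title not in seen and title.lower() in passage.lower():
            return title
    return None

def multi_hop(question: str, max_hops: int = 3) -> List[str]:
    evidence: List[str] = []
    seen: Set[str] = set()
    query: Optional[str] = question
    for _ in range(max_hops):
        hit = lookup(query) if query else None
        if hit is None or hit[0] in seen:
            break  # nothing new to retrieve
        title, passage = hit
        seen.add(title)
        evidence.append(passage)
        query = propose_next_query(passage, seen)
    return evidence

print(multi_hop("Where did Prometheus grow?"))
```

Here the loop first retrieves the Prometheus passage and then follows its mention of Nevada. In a real system, propose_next_query() would be an LLM prompted to reason about what it still needs to know.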
It’s also crucial to remember that RAG does not solve common LLM pitfalls (hallucination, flattery, etc.); it only mitigates them or guides your LLM toward a more specialized response. The outcomes that ultimately matter depend on your use case, the information you provide to your model, and how the model is fine-tuned.
Kevin Vu manages the Exxact Corp. Blog and collaborates with several talented authors who write about various aspects of Deep Learning.