RAG: Retrieval-Augmented Generation#
Author: Amir Esfandiari
Introduction#
Generative artificial intelligence (AI) excels at creating text responses based on large language models (LLMs) where the AI is trained on a massive number of data points. The good news is that the generated text is often easy to read and provides detailed responses that are broadly applicable to the questions asked of the software, often called prompts.
The bad news is that the information used to generate the response is limited to the information used to train the AI, often a generalized LLM. The LLM’s data may be weeks, months, or years out of date and in a corporate AI chatbot may not include specific information about the organization’s products or services. That can lead to incorrect responses that erode confidence in the technology among customers and employees.
What is Retrieval-Augmented Generation (RAG)?#
That’s where retrieval-augmented generation (RAG) comes in. RAG provides a way to optimize the output of an LLM with targeted information without modifying the underlying model itself; that targeted information can be more up-to-date than the LLM as well as specific to a particular organization and industry. That means the generative AI system can provide more contextually appropriate answers to prompts as well as base those answers on extremely current data.
RAG first came to the attention of generative AI developers after the publication of “Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks,” a 2020 paper published by Patrick Lewis and a team at Facebook AI Research. The RAG concept has been embraced by many academic and industry researchers, who see it as a way to significantly improve the value of generative AI systems.
Retrieval-Augmented Generation Explained#
Consider a sports league that wants fans and the media to be able to use chat to access its data and answer questions about players, teams, the sport’s history and rules, and current stats and standings. A generalized LLM could answer questions about the history and rules or perhaps describe a particular team’s stadium. It wouldn’t be able to discuss last night’s game or provide current information about a particular athlete’s injury because the LLM wouldn’t have that information—and given that an LLM takes significant computing horsepower to retrain, it isn’t feasible to keep the model current.
In addition to the large, fairly static LLM, the sports league owns or can access many other information sources, including databases, data warehouses, documents containing player bios, and news feeds that discuss each game in depth. RAG lets the generative AI ingest this information. Now, the chat can provide information that’s more timely, more contextually appropriate, and more accurate.
Simply put, RAG helps LLMs give better answers.
How Does Retrieval-Augmented Generation Work?#
Consider all the information that an organization has—the structured databases, the unstructured PDFs and other documents, the blogs, the news feeds, the chat transcripts from past customer service sessions. In RAG, this vast quantity of dynamic data is translated into a common format and stored in a knowledge library that’s accessible to the generative AI system.
The data in that knowledge library is then processed into numerical representations using an embedding model (an algorithm that converts text into vectors) and stored in a vector database, which can be quickly searched and used to retrieve the correct contextual information.
RAG and Large Language Models (LLMs)#
Now, say an end user sends the generative AI system a specific prompt, for example, “Where will tonight’s game be played, who are the starting players, and what are reporters saying about the matchup?” The query is transformed into a vector and used to query the vector database, which retrieves information relevant to that question’s context. That contextual information plus the original prompt are then fed into the LLM, which generates a text response based on both its somewhat out-of-date generalized knowledge and the extremely timely contextual information.
Interestingly, while training the generalized LLM is time-consuming and costly, updating the RAG knowledge library is just the opposite. New data can be run through the embedding model and translated into vectors on a continuous, incremental basis. In fact, the answers produced by the whole generative AI system can be fed back into the knowledge library, improving its performance and accuracy, because, in effect, the system can see how it has already answered similar questions.
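To make the incremental-update idea concrete, here is a minimal sketch, assuming the sentence-transformers embedding model and FAISS index used in the Implementation section below; the add_documents helper and the fed-back answer are hypothetical.

from sentence_transformers import SentenceTransformer
import faiss

embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

def add_documents(index, doc_store, new_docs):
    """Embed new documents and append them to an existing FAISS index."""
    vectors = embedder.encode(new_docs).astype('float32')
    index.add(vectors)          # vectors are appended; nothing is retrained
    doc_store.extend(new_docs)  # keep the raw text aligned with index positions
    return index, doc_store

# Hypothetical example: feeding a past system answer back into the knowledge library
# index, doc_store = add_documents(index, doc_store, ["Q: Who won last night? A: ..."])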
An additional benefit of RAG is that by using the vector database, the generative AI can provide the specific source of data cited in its answer—something LLMs can’t do. Therefore, if there’s an inaccuracy in the generative AI’s output, the document that contains that erroneous information can be quickly identified and corrected, and then the corrected information can be fed into the vector database.
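As a hedged sketch of how source attribution can work, the snippet below assumes each document in the vector store is kept alongside a hypothetical sources list recording where it came from; the helper name and usage lines are illustrative, not a standard API.

def retrieve_with_sources(query, index, documents, sources, embedder, k=2):
    """Return the top-k documents together with the source each one came from."""
    query_vector = embedder.encode([query]).astype('float32')
    _, indices = index.search(query_vector, k)
    return [(documents[i], sources[i]) for i in indices[0]]

# Hypothetical usage, with one source label per document:
# sources = ["stats_db", "player_bios.pdf", "news_feed"]
# for doc, src in retrieve_with_sources("Who is injured?", index, documents, sources, embedder):
#     print(f"{doc}  (source: {src})")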
In short, RAG provides timeliness, context, and accuracy grounded in evidence to generative AI, going beyond what the LLM itself can provide.
Challenges of Retrieval-Augmented Generation#
Because RAG is a relatively new technology, first proposed in 2020, AI developers are still learning how to best implement its information retrieval mechanisms in generative AI. Some key challenges are:
- Improving organizational knowledge and understanding of RAG because it’s so new
- Increasing costs; while generative AI with RAG will be more expensive to implement than an LLM on its own, this route is less costly than frequently retraining the LLM itself
- Determining how to best model the structured and unstructured data within the knowledge library and vector database
- Developing requirements for a process to incrementally feed data into the RAG system
- Putting processes in place to handle reports of inaccuracies and to correct or delete those information sources in the RAG system
Implementation#
Understanding RAG (Retrieval-Augmented Generation)#
This notebook explains RAG (Retrieval-Augmented Generation), a technique that combines document retrieval with language model generation to create more accurate and factual responses.
Required Libraries#
# Import necessary libraries
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
import faiss
import numpy as np
1. Document Preparation#
First, let’s prepare some sample documents and create a simple document store:
# Sample documents
documents = [
    "RAG stands for Retrieval-Augmented Generation in machine learning.",
    "RAG combines retrieval and generation models for better accuracy.",
    "Vector databases are crucial components in RAG systems.",
    "Document chunking is important for effective RAG implementation.",
    "RAG helps reduce hallucination in language models."
]

def prepare_documents(docs):
    """
    Prepare documents for embedding
    """
    return docs
2. Creating Embeddings#
We’ll use SentenceTransformers to create embeddings:
def create_embeddings(documents):
    """
    Create embeddings for documents using SentenceTransformer
    """
    model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
    embeddings = model.encode(documents)
    return embeddings

# Create embeddings for our documents
document_embeddings = create_embeddings(documents)
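A quick sanity check on the result; the expected shape assumes the five sample documents above and the 384-dimensional vectors produced by all-MiniLM-L6-v2:

# One row per document; all-MiniLM-L6-v2 returns 384-dimensional vectors
print(document_embeddings.shape)  # expected: (5, 384)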
3. Building the Vector Store#
Now we’ll create a simple vector store using FAISS:
def build_vector_store(embeddings):
    """
    Build a FAISS index for fast similarity search
    """
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)
    index.add(embeddings.astype('float32'))
    return index

# Build the vector store
vector_store = build_vector_store(document_embeddings)
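Another quick check: the index should now hold one vector per document.

# IndexFlatL2 exposes ntotal, the number of stored vectors
print(vector_store.ntotal)  # expected: 5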
4. Implementing RAG#
Here’s a basic RAG implementation:
class SimpleRAG:
    def __init__(self, documents, vector_store):
        self.documents = documents
        self.vector_store = vector_store
        self.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

    def retrieve(self, query, k=2):
        """
        Retrieve relevant documents for the query
        """
        # Create query embedding
        query_embedding = self.embedding_model.encode([query])

        # Search in vector store
        distances, indices = self.vector_store.search(
            query_embedding.astype('float32'),
            k
        )

        # Return relevant documents
        return [self.documents[i] for i in indices[0]]

    def generate(self, query, retrieved_docs):
        """
        Generate response using retrieved documents
        This is a simplified version - in practice, you'd use a proper LLM
        """
        context = "\n".join(retrieved_docs)
        prompt = f"""
Context: {context}
Question: {query}
Answer:"""
        return prompt  # In practice, you'd send this to an LLM
5. Using the RAG System#
Let’s try out our RAG system:
# Initialize RAG
rag = SimpleRAG(documents, vector_store)
# Example query
query = "What is RAG and why is it useful?"
# Retrieve relevant documents
retrieved_docs = rag.retrieve(query)
print("Retrieved Documents:")
for i, doc in enumerate(retrieved_docs):
    print(f"{i+1}. {doc}")
# Generate response
prompt = rag.generate(query, retrieved_docs)
print("\nGenerated Prompt:")
print(prompt)
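The notebook imports torch, AutoTokenizer, and AutoModelForCausalLM but never calls them; as a hedged illustration of the final step, the prompt could be passed to a small causal LM like this. The gpt2 checkpoint is only a placeholder assumption; a production system would use a stronger, instruction-tuned model.

# Pass the retrieval-augmented prompt to a small causal LM (gpt2 is a placeholder choice)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
llm = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
    output_ids = llm.generate(
        **inputs,
        max_new_tokens=50,
        pad_token_id=tokenizer.eos_token_id,  # gpt2 has no pad token; reuse EOS
    )

# Decode only the newly generated tokens, not the prompt itself
answer = tokenizer.decode(output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print("\nGenerated Answer:")
print(answer)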
6. Advanced Considerations#
Here are some important considerations when implementing RAG in production:
Document Chunking:
def chunk_documents(documents, chunk_size=512):
    """
    Split documents into smaller chunks
    This is a simplified version
    """
    chunks = []
    for doc in documents:
        # In practice, you'd want more sophisticated chunking
        if len(doc) > chunk_size:
            chunks.extend([doc[i:i+chunk_size]
                           for i in range(0, len(doc), chunk_size)])
        else:
            chunks.append(doc)
    return chunks
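A hedged usage sketch: chunking normally happens before embedding, so it is the chunks rather than the full documents that end up in the index (reusing the helpers defined earlier in this notebook).

# Chunk the corpus, then re-embed and re-index the chunks
chunks = chunk_documents(documents, chunk_size=512)
chunk_embeddings = create_embeddings(chunks)
chunk_index = build_vector_store(chunk_embeddings)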
Hybrid Search:
def hybrid_search(query, vector_results, keyword_results, alpha=0.5):
    """
    Combine vector and keyword search results
    This is a simplified version
    """
    # In practice, you'd implement a more sophisticated
    # combination strategy
    combined_results = {}
    for doc, score in vector_results.items():
        combined_results[doc] = score * alpha
    for doc, score in keyword_results.items():
        if doc in combined_results:
            combined_results[doc] += score * (1 - alpha)
        else:
            combined_results[doc] = score * (1 - alpha)
    return combined_results
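hybrid_search expects per-document scores from both retrieval paths. As a hedged illustration, the helper below fakes the keyword side with a naive term-overlap count; a real system would use a proper keyword engine such as BM25, and the usage lines are hypothetical.

def keyword_scores(query, documents):
    """Naive keyword scoring: count how many query terms appear in each document."""
    terms = set(query.lower().split())
    return {doc: sum(term in doc.lower() for term in terms) for doc in documents}

# Hypothetical usage, assuming vector-side scores have already been computed per document:
# vector_results = {doc: 1.0 for doc in retrieved_docs}   # e.g., derived from FAISS distances
# combined = hybrid_search(query, vector_results, keyword_scores(query, documents))
# print(sorted(combined.items(), key=lambda x: x[1], reverse=True))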
7. Best Practices#
Regular Index Updates:
- Implement a strategy to update your vector store regularly
- Consider incremental updates for efficiency

Error Handling:
- Implement robust error handling for embedding and retrieval
- Have fallback strategies when retrieval fails

Monitoring:
- Track retrieval quality metrics
- Monitor embedding and generation latency (see the timing sketch after this list)
- Implement feedback loops for continuous improvement
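As a minimal sketch of latency monitoring (the stage names and the timed helper are illustrative, not a standard API):

import time

def timed(stage, fn, *args, **kwargs):
    """Run fn, print how long the stage took, and return its result."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    print(f"[{stage}] {(time.perf_counter() - start) * 1000:.1f} ms")
    return result

# Hypothetical usage with the SimpleRAG instance defined above:
# docs = timed("retrieval", rag.retrieve, query)
# prompt = timed("generation", rag.generate, query, docs)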
Conclusion#
RAG is a powerful technique that can significantly improve the accuracy and reliability of language model outputs. This notebook provided a basic implementation, but production systems would need additional considerations for scalability, reliability, and performance optimization.
Remember to:
- Choose appropriate embedding models
- Implement efficient document chunking
- Consider hybrid search approaches
- Monitor and optimize performance
- Regularly update your knowledge base