RAG System Overview
RAG (Retrieval-Augmented Generation) is an AI technique that combines retrieval with generation: it retrieves relevant documents and uses them to ground and improve a language model's answers. Compared with a standalone LLM, a RAG system offers:
- Flexible knowledge updates: no model retraining required
- Fewer hallucinations: answers are grounded in real documents
- Cost efficiency: reduced token usage
- Traceability: information sources can be cited
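At its core, a RAG query is a three-step loop: retrieve relevant chunks, splice them into the prompt, and generate. A minimal sketch of that loop (the names here are illustrative, not a specific library's API):

def rag_answer(question, retriever, llm):
    # 1. Retrieve: fetch the documents most relevant to the question
    docs = retriever.get_relevant_documents(question)
    # 2. Augment: splice the retrieved text into the prompt
    context = "\n\n".join(d.page_content for d in docs)
    prompt = f"Answer using only this context:\n{context}\n\nQuestion: {question}"
    # 3. Generate: the LLM answers against the grounded prompt
    return llm.predict(prompt)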
System Architecture Design
Core Components
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.llms import OpenAI
from langchain.chains import RetrievalQA
from langchain.document_loaders import (
    DirectoryLoader,
    PyPDFLoader,
    TextLoader,
    UnstructuredMarkdownLoader,
    Docx2txtLoader,
)

class RAGSystem:
    def __init__(self,
                 documents_path: str,
                 persist_directory: str = "./chroma_db"):
        self.documents_path = documents_path
        self.persist_directory = persist_directory
        self.embeddings = OpenAIEmbeddings()
        self.llm = OpenAI(temperature=0)
        self.vectorstore = None
        self.qa_chain = None

    def load_documents(self):
        """Load documents."""
        # Support multiple formats: map each glob pattern to its loader class
        loaders = {
            '*.pdf': PyPDFLoader,
            '*.txt': TextLoader,
            '*.md': UnstructuredMarkdownLoader,
            '*.docx': Docx2txtLoader,
        }
        documents = []
        for pattern, loader_cls in loaders.items():
            loader = DirectoryLoader(
                self.documents_path,
                glob=pattern,
                loader_cls=loader_cls
            )
            documents.extend(loader.load())
        return documents

    def process_documents(self, documents):
        """Process documents: chunking and cleaning."""
        # Structure-aware text splitting
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
            separators=["\n\n", "\n", " ", ""]
        )
        texts = text_splitter.split_documents(documents)
        # Clean each chunk's text
        for text in texts:
            text.page_content = self.clean_text(text.page_content)
        return texts

    def clean_text(self, text: str) -> str:
        """Clean text."""
        # Collapse redundant whitespace
        text = ' '.join(text.split())
        # Remove null bytes
        text = text.replace('\x00', '')
        return text

    def create_vectorstore(self, texts):
        """Create the vector store."""
        self.vectorstore = Chroma.from_documents(
            documents=texts,
            embedding=self.embeddings,
            persist_directory=self.persist_directory
        )
        self.vectorstore.persist()

    def setup_qa_chain(self):
        """Set up the QA chain."""
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.vectorstore.as_retriever(
                search_type="similarity",
                search_kwargs={"k": 4}
            ),
            return_source_documents=True
        )
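The API service later in this article calls rag_system.initialize() and rag_system.query(...), which the class above does not yet define. A minimal sketch of both methods (they belong on RAGSystem; the return shape of query is an assumption chosen to match the API section):

    def initialize(self):
        """Run the full pipeline: load, process, index, then wire up the QA chain."""
        documents = self.load_documents()
        texts = self.process_documents(documents)
        self.create_vectorstore(texts)
        self.setup_qa_chain()

    def query(self, question: str, k: int = 4):
        """Answer a question and return the answer plus source snippets."""
        self.qa_chain.retriever.search_kwargs["k"] = k
        result = self.qa_chain({"query": question})  # RetrievalQA's input key is "query"
        return {
            "answer": result["result"],
            "sources": [
                {"content": doc.page_content[:200], "metadata": doc.metadata}
                for doc in result["source_documents"]
            ],
        }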
Document Processing Optimization
1. Smart Chunking Strategy
from typing import List
import re

class SmartTextSplitter:
    def __init__(self,
                 chunk_size: int = 1000,
                 chunk_overlap: int = 200):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    def split_by_structure(self, text: str) -> List[str]:
        """Structure-aware splitting based on the document's layout."""
        chunks = []
        # Split on Markdown headings (levels 1-3)
        chapters = re.split(r'\n#{1,3}\s', text)
        for chapter in chapters:
            if len(chapter) > self.chunk_size:
                # Further split long chapters
                sub_chunks = self.split_by_paragraph(chapter)
                chunks.extend(sub_chunks)
            else:
                chunks.append(chapter)
        return self.add_overlap(chunks)

    def split_by_paragraph(self, text: str) -> List[str]:
        """Split by paragraph, packing paragraphs up to the chunk size."""
        paragraphs = text.split('\n\n')
        chunks = []
        current_chunk = ""
        for para in paragraphs:
            if len(current_chunk) + len(para) <= self.chunk_size:
                current_chunk += para + "\n\n"
            else:
                if current_chunk:
                    chunks.append(current_chunk.strip())
                current_chunk = para + "\n\n"
        if current_chunk:
            chunks.append(current_chunk.strip())
        return chunks

    def add_overlap(self, chunks: List[str]) -> List[str]:
        """Add overlap between neighboring chunks."""
        overlapped_chunks = []
        for i, chunk in enumerate(chunks):
            if i > 0:
                # Prepend the tail of the previous chunk
                prev_end = chunks[i - 1][-self.chunk_overlap:]
                chunk = prev_end + " " + chunk
            if i < len(chunks) - 1:
                # Append the head of the next chunk
                next_start = chunks[i + 1][:self.chunk_overlap]
                chunk = chunk + " " + next_start
            overlapped_chunks.append(chunk)
        return overlapped_chunks
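A quick usage sketch (the sample text is made up):

splitter = SmartTextSplitter(chunk_size=500, chunk_overlap=50)
sample = "# Intro\n\nFirst paragraph...\n\n## Details\n\nSecond paragraph..."
for i, chunk in enumerate(splitter.split_by_structure(sample)):
    print(f"chunk {i}: {len(chunk)} chars")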
2. Metadata Enrichment
from datetime import datetime
from typing import List

class MetadataEnricher:
    def enrich_documents(self, documents):
        """Attach rich metadata to each document."""
        for doc in documents:
            # Basic metadata
            doc.metadata['processed_at'] = datetime.now().isoformat()
            doc.metadata['char_count'] = len(doc.page_content)
            doc.metadata['word_count'] = len(doc.page_content.split())
            # Content analysis
            doc.metadata['has_code'] = '```' in doc.page_content
            doc.metadata['has_table'] = '|' in doc.page_content
            doc.metadata['language'] = self.detect_language(doc.page_content)
            # Topic extraction
            doc.metadata['keywords'] = self.extract_keywords(doc.page_content)
            doc.metadata['summary'] = self.generate_summary(doc.page_content)
        return documents

    def extract_keywords(self, text: str, top_k: int = 5) -> List[str]:
        """Simple TF-IDF keyword extraction."""
        from sklearn.feature_extraction.text import TfidfVectorizer
        vectorizer = TfidfVectorizer(
            max_features=top_k,
            stop_words='english'
        )
        try:
            vectorizer.fit_transform([text])
            return list(vectorizer.get_feature_names_out())
        except ValueError:
            # e.g. the text is empty or contains only stop words
            return []

    def detect_language(self, text: str) -> str:
        """Placeholder detection; swap in e.g. langdetect in production."""
        # Crude heuristic: treat text containing CJK characters as Chinese
        return 'zh' if any('\u4e00' <= ch <= '\u9fff' for ch in text) else 'en'

    def generate_summary(self, text: str) -> str:
        """Placeholder summary (first 200 characters); swap in an LLM call if needed."""
        return text[:200]
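Enrichment slots in between splitting and indexing. Assuming rag is the RAGSystem instance from the architecture section:

texts = rag.process_documents(rag.load_documents())
texts = MetadataEnricher().enrich_documents(texts)
rag.create_vectorstore(texts)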
Retrieval Optimization Strategies
1. Hybrid Retrieval
from typing import List
from langchain.schema import Document
from langchain.retrievers import BM25Retriever, EnsembleRetriever

class HybridRetriever:
    def __init__(self, vectorstore, documents):
        # Vector (dense) retriever
        self.vector_retriever = vectorstore.as_retriever(
            search_kwargs={"k": 5}
        )
        # BM25 retriever (keyword matching)
        self.bm25_retriever = BM25Retriever.from_documents(documents)
        self.bm25_retriever.k = 5
        # Ensemble of the two
        self.ensemble_retriever = EnsembleRetriever(
            retrievers=[self.vector_retriever, self.bm25_retriever],
            weights=[0.6, 0.4]  # weight vector retrieval higher
        )

    def retrieve(self, query: str) -> List[Document]:
        """Run the hybrid retrieval."""
        return self.ensemble_retriever.get_relevant_documents(query)
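Usage, assuming the vector store and the pre-split documents from the RAGSystem above:

retriever = HybridRetriever(rag.vectorstore, texts)
hits = retriever.retrieve("how is chunk overlap configured?")
print(f"{len(hits)} candidate documents")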
2. Re-ranking
from typing import List
from langchain.schema import Document
from sentence_transformers import CrossEncoder

class DocumentReranker:
    def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2'):
        self.model = CrossEncoder(model_name)

    def rerank(self, query: str, documents: List[Document],
               top_k: int = 3) -> List[Document]:
        """Re-rank retrieved documents by cross-encoder relevance."""
        # Build (query, document) input pairs
        pairs = [[query, doc.page_content] for doc in documents]
        # Score each pair
        scores = self.model.predict(pairs)
        # Sort by score and keep the top-k
        doc_score_pairs = list(zip(documents, scores))
        doc_score_pairs.sort(key=lambda x: x[1], reverse=True)
        reranked_docs = [doc for doc, _ in doc_score_pairs[:top_k]]
        return reranked_docs
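Retrieval and re-ranking compose into the usual over-fetch-then-filter pipeline, reusing the two classes above:

query = "how is chunk overlap configured?"
reranker = DocumentReranker()
candidates = retriever.retrieve(query)                  # over-fetch with the hybrid retriever
top_docs = reranker.rerank(query, candidates, top_k=3)  # keep only the best 3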
Prompt Engineering Optimization
1. Dynamic Prompt Templates
from langchain.prompts import PromptTemplate

class DynamicPromptBuilder:
    def __init__(self):
        self.templates = {
            "technical": """
You are a technical expert. Answer the user's question based on the document content below.
If the documents contain no relevant information, say so explicitly.

Document content:
{context}

User question: {question}

Provide a detailed technical answer covering:
1. Core concepts
2. Implementation steps (where applicable)
3. Caveats

Answer:
""",
            "summary": """
Based on the document content below, give the user a concise, summary-style answer.

Document content:
{context}

User question: {question}

Answer requirements:
- Concise, no more than 3 paragraphs
- Highlight the key information
- Easy to skim

Answer:
""",
            "analytical": """
As a data analyst, provide an in-depth analysis based on the document content.

Document content:
{context}

User question: {question}

Analysis dimensions:
1. Data insights
2. Trend analysis
3. Recommended actions

Analysis:
"""
        }

    def get_prompt(self, query: str, prompt_type: str = None) -> PromptTemplate:
        """Pick a prompt template based on the query type."""
        if not prompt_type:
            prompt_type = self.detect_query_type(query)
        template = self.templates.get(prompt_type, self.templates["technical"])
        return PromptTemplate(
            input_variables=["context", "question"],
            template=template
        )

    def detect_query_type(self, query: str) -> str:
        """Detect the query type from simple keyword cues."""
        query_lower = query.lower()
        if any(word in query_lower for word in ["summarize", "summary", "brief"]):
            return "summary"
        elif any(word in query_lower for word in ["analyze", "compare", "evaluate"]):
            return "analytical"
        else:
            return "technical"
2. Context Optimization
from typing import List
from langchain.schema import Document

class ContextOptimizer:
    def __init__(self, max_context_length: int = 3000):
        self.max_context_length = max_context_length

    def optimize_context(self, documents: List[Document],
                         query: str) -> str:
        """Optimize the context, balancing relevance against length."""
        # Sort by relevance first
        sorted_docs = self.sort_by_relevance(documents, query)
        # Truncate intelligently
        context_parts = []
        current_length = 0
        for doc in sorted_docs:
            content = doc.page_content
            # Pull out the most relevant passages
            relevant_parts = self.extract_relevant_parts(content, query)
            for part in relevant_parts:
                if current_length + len(part) <= self.max_context_length:
                    context_parts.append(part)
                    current_length += len(part)
                else:
                    # Truncate to the maximum length
                    remaining = self.max_context_length - current_length
                    context_parts.append(part[:remaining] + "...")
                    break
            if current_length >= self.max_context_length:
                break
        return "\n\n---\n\n".join(context_parts)

    def sort_by_relevance(self, documents: List[Document],
                          query: str) -> List[Document]:
        """Simple word-overlap relevance sort (placeholder for a real scorer)."""
        query_words = set(query.lower().split())
        return sorted(
            documents,
            key=lambda d: len(query_words & set(d.page_content.lower().split())),
            reverse=True
        )

    def extract_relevant_parts(self, content: str,
                               query: str) -> List[str]:
        """Extract the most relevant parts of a document."""
        # Split into sentences
        sentences = content.split('.')
        # Score each sentence by word overlap with the query
        query_words = set(query.lower().split())
        scored_sentences = []
        for i, sentence in enumerate(sentences):
            sentence_words = set(sentence.lower().split())
            score = len(query_words & sentence_words)
            # Include surrounding sentences for context
            context_range = 2
            start = max(0, i - context_range)
            end = min(len(sentences), i + context_range + 1)
            context = '.'.join(sentences[start:end])
            scored_sentences.append((context, score))
        # Return the highest-scoring windows
        scored_sentences.sort(key=lambda x: x[1], reverse=True)
        return [s[0] for s in scored_sentences[:3]]
Advanced Features
1. Multi-turn Conversation Support
from typing import Any, Dict, List
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain

class ConversationalRAG:
    def __init__(self, vectorstore, llm):
        self.vectorstore = vectorstore
        self.llm = llm
        self.memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
            output_key="answer"  # required when the chain also returns source documents
        )
        self.qa_chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=vectorstore.as_retriever(),
            memory=self.memory,
            return_source_documents=True
        )

    def chat(self, question: str) -> Dict[str, Any]:
        """Handle one turn of the conversation."""
        result = self.qa_chain({"question": question})
        # Format the response
        response = {
            "answer": result["answer"],
            "sources": [
                {
                    "content": doc.page_content[:200] + "...",
                    "metadata": doc.metadata
                }
                for doc in result["source_documents"]
            ],
            "chat_history": self.get_chat_history()
        }
        return response

    def get_chat_history(self) -> List[Dict[str, str]]:
        """Return the conversation history."""
        messages = self.memory.chat_memory.messages
        history = []
        for msg in messages:
            history.append({
                "role": "human" if msg.type == "human" else "assistant",
                "content": msg.content
            })
        return history
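A short interactive exchange (the questions are illustrative); the second turn is resolved against the chat history:

chat_rag = ConversationalRAG(rag.vectorstore, rag.llm)
first = chat_rag.chat("What chunk size does the system use?")
follow_up = chat_rag.chat("And what about the overlap?")
print(follow_up["answer"])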
2. Multimodal Support
import io

from langchain.schema import Document
from PIL import Image
import pytesseract

class MultiModalRAG:
    def __init__(self, base_rag_system):
        self.base_rag = base_rag_system
        self.image_processor = ImageProcessor()

    def process_multimodal_document(self, file_path: str):
        """Process a multimodal document."""
        if file_path.endswith('.pdf'):
            # Extract text and images
            texts, images = self.extract_from_pdf(file_path)
            # Process the images
            image_texts = []
            for img in images:
                img_text = self.image_processor.extract_text(img)
                img_description = self.image_processor.describe_image(img)
                image_texts.append(f"{img_text}\nImage description: {img_description}")
            # Merge everything into one document
            full_content = "\n".join(texts + image_texts)
            return self.base_rag.process_documents([
                Document(page_content=full_content, metadata={"source": file_path})
            ])

    def extract_from_pdf(self, pdf_path: str):
        """Extract text and images from a PDF."""
        import fitz  # PyMuPDF
        texts = []
        images = []
        pdf_document = fitz.open(pdf_path)
        for page_num in range(pdf_document.page_count):
            page = pdf_document[page_num]
            # Extract text
            texts.append(page.get_text())
            # Extract images
            image_list = page.get_images(full=True)
            for img_index, img in enumerate(image_list):
                xref = img[0]
                pix = fitz.Pixmap(pdf_document, xref)
                if pix.n - pix.alpha < 4:  # GRAY or RGB
                    img_data = pix.tobytes("png")
                    images.append(Image.open(io.BytesIO(img_data)))
                pix = None  # release the pixmap
        return texts, images

class ImageProcessor:
    def extract_text(self, image: Image.Image) -> str:
        """Extract text from an image (OCR)."""
        return pytesseract.image_to_string(image, lang='chi_sim+eng')

    def describe_image(self, image: Image.Image) -> str:
        """Generate an image description."""
        # Plug an image-captioning model in here,
        # e.g. CLIP or BLIP
        return "image description placeholder"
Performance Optimization
1. Caching
import hashlib
import json
from functools import lru_cache
from typing import Any, Dict, List

import redis

class CachedRAG:
    def __init__(self, rag_system, redis_host='localhost', redis_port=6379):
        self.rag = rag_system
        self.cache = redis.Redis(host=redis_host, port=redis_port, db=0)
        self.cache_ttl = 3600  # 1 hour

    def query(self, question: str) -> Dict[str, Any]:
        """Query with caching."""
        # Build the cache key
        cache_key = self._generate_cache_key(question)
        # Try the cache first
        cached_result = self.cache.get(cache_key)
        if cached_result:
            return json.loads(cached_result)
        # Run the real query
        result = self.rag.query(question)
        # Cache the result with a TTL
        self.cache.setex(
            cache_key,
            self.cache_ttl,
            json.dumps(result)
        )
        return result

    def _generate_cache_key(self, question: str) -> str:
        """Build a cache key from the question's hash."""
        question_hash = hashlib.md5(question.encode()).hexdigest()
        return f"rag_cache:{question_hash}"

    @lru_cache(maxsize=100)
    def get_similar_questions(self, question: str) -> List[str]:
        """Find similar questions (to improve the cache hit rate)."""
        # Similar-question retrieval could be implemented here,
        # e.g. by embedding questions and doing nearest-neighbor search
        pass
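Wrapping the base system is a one-liner; the second identical query is served from Redis (assumes a local Redis on the default port):

cached = CachedRAG(rag_system)
cached.query("What is chunk overlap?")   # cache miss: runs the full RAG pipeline
cached.query("What is chunk overlap?")   # cache hit: returned straight from Redis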
2. Batch Processing
import asyncio
from typing import Any, Dict, List

class BatchRAG:
    def __init__(self, rag_system, batch_size: int = 10):
        self.rag = rag_system
        self.batch_size = batch_size

    async def batch_query(self, questions: List[str]) -> List[Dict[str, Any]]:
        """Process queries in batches."""
        results = []
        for i in range(0, len(questions), self.batch_size):
            batch = questions[i:i + self.batch_size]
            # Run the queries within a batch concurrently
            batch_results = await asyncio.gather(
                *[self._async_query(q) for q in batch]
            )
            results.extend(batch_results)
        return results

    async def _async_query(self, question: str) -> Dict[str, Any]:
        """Run a blocking query in a thread executor."""
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, self.rag.query, question)
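Driving the batch interface from synchronous code (the questions are illustrative):

batcher = BatchRAG(rag_system, batch_size=5)
questions = ["What is RAG?", "How does re-ranking work?"]
answers = asyncio.run(batcher.batch_query(questions))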
Deployment Best Practices
1. API Service
from typing import Dict, List

import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

app = FastAPI()

# Initialize the RAG system
rag_system = RAGSystem(documents_path="./documents")
rag_system.initialize()

class QueryRequest(BaseModel):
    question: str
    max_results: int = 3
    include_sources: bool = True

class QueryResponse(BaseModel):
    answer: str
    sources: List[Dict] = []
    confidence: float

@app.post("/query", response_model=QueryResponse)
async def query(request: QueryRequest):
    try:
        result = rag_system.query(
            request.question,
            k=request.max_results
        )
        response = QueryResponse(
            answer=result["answer"],
            sources=result["sources"] if request.include_sources else [],
            confidence=result.get("confidence", 0.8)
        )
        return response
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/update_documents")
async def update_documents(file_paths: List[str]):
    """Update the document store (assumes an add_documents method on RAGSystem)."""
    try:
        rag_system.add_documents(file_paths)
        return {"status": "success", "message": f"Added {len(file_paths)} documents"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
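A quick client-side check against the running service, using requests (the question is illustrative):

import requests

resp = requests.post(
    "http://localhost:8000/query",
    json={"question": "How is document chunking configured?", "max_results": 3},
)
print(resp.json()["answer"])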
2. Docker Deployment
FROM python:3.9-slim

WORKDIR /app

# Install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy the application code
COPY . .

# Create data directories
RUN mkdir -p /app/documents /app/chroma_db

# Expose the API port
EXPOSE 8000

# Start the service
CMD ["python", "main.py"]
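Typical build and run commands (the image name and volume mount are illustrative):

docker build -t rag-service .
docker run -p 8000:8000 -v $(pwd)/documents:/app/documents rag-service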
Summary
Building a production-grade RAG system means getting several layers right:
- Document processing: smart chunking, metadata enrichment
- Retrieval optimization: hybrid retrieval, re-ranking
- Generation quality: prompt engineering, context optimization
- System performance: caching, batching, async processing
- Scalability: modular design, API services
With sound architecture and the optimization strategies above, a RAG system can deliver accurate, reliable knowledge services. As the technology matures, RAG will play an increasingly important role in enterprise knowledge management, intelligent customer support, education, and beyond.