Files
chat-bot/extension/rag.py
lychang 64ce30fdfd init
2025-08-26 09:35:29 +08:00

143 lines
5.6 KiB
Python

import time
from typing import List
import jieba
import jieba.analyse
from lancedb import connect
from core.model import get_embedding_model, get_chat_model
from extension.document_loader import DocumentLoader
from extension.standard import OnlinePrompt, parse_json_string
class RAGPipeline:
    """Retrieval pipeline that stores embedded text chunks in LanceDB.

    Documents (or raw text) are split into chunks, embedded, and written to
    a LanceDB table. Queries are rewritten by the chat model, reduced to a
    set of keywords, and answered with hybrid (vector + full-text) search.
    """

    # Columns handed back to callers from every search.
    _RESULT_COLUMNS = ["text", "extension", "time_stamp"]

    # Translation table that drops newlines, tabs and carriage returns.
    _CONTROL_CHARS = str.maketrans("", "", "\n\t\r")

    def __init__(
        self,
        db_path: str = "./db/vec",
        default_table_name: str = "documents",
        embedding_model: str = "bge-large",
        llm_model: str = "qwen2.5:7b",
        chunk_size: int = 512,
        chunk_overlap: int = 200,
    ):
        """Open the vector store and load the models.

        Args:
            db_path: Location passed to ``lancedb.connect``.
            default_table_name: Table used when a call gives no table name.
            embedding_model: Name passed to ``get_embedding_model``.
            llm_model: Name passed to ``get_chat_model``.
            chunk_size: Forwarded to ``DocumentLoader`` as ``chunk_size``.
            chunk_overlap: Forwarded to ``DocumentLoader`` as ``chunk_overlap``.
        """
        self.embeddings = get_embedding_model(embedding_model)
        self.llm = get_chat_model(llm_model)
        self.db = connect(db_path)
        self.default_table_name = default_table_name
        self.document_loader = DocumentLoader(
            config={"chunk_size": chunk_size, "chunk_overlap": chunk_overlap}
        )

    def _insert(self, texts: list, extension: str, table_name: str = None):
        """Embed ``texts`` and append them to the target table.

        Args:
            texts: Text chunks to store. An empty list is a no-op.
            extension: Source-type tag stored with every row.
            table_name: Target table; defaults to ``default_table_name``.
        """
        if not texts:
            # Nothing to embed; also avoids creating a table from empty data.
            return
        table_name = table_name or self.default_table_name
        vectors = self.embeddings.embed_documents(texts)
        # All rows of one insert share the same timestamp (whole seconds).
        time_stamp = int(time.time())
        data = []
        for text, vector in zip(texts, vectors):
            flat_text = text.translate(self._CONTROL_CHARS)
            data.append(
                {
                    "text": flat_text,
                    # Space-free copy; this is the column the FTS index covers.
                    "text_fts": flat_text.replace(" ", ""),
                    "vector": vector,
                    "extension": extension,
                    "time_stamp": time_stamp,
                }
            )
        if table_name in self.db.table_names():
            self.db.open_table(table_name).add(data)
        else:
            # The schema is inferred from ``data``. No FTS index is created
            # here; ``query`` builds it on first use.
            self.db.create_table(table_name, data=data, mode="overwrite")

    def insert_document(self, file_path: str, table_name: str = None, extension: str = None):
        """Load a file, split it into chunks and store the chunks.

        Args:
            file_path: Path handed to ``DocumentLoader.load_content``.
            table_name: Target table; defaults to ``default_table_name``.
            extension: Optional extension hint passed to the loader.
        """
        self.document_loader.load_content(file_path, extension)
        chunks = self.document_loader.load_and_split()
        texts = [chunk.page_content for chunk in chunks]
        # The loader's own ``extension`` attribute is what gets stored.
        self._insert(texts, self.document_loader.extension, table_name)

    def insert_text_content(self, content: str, table_name: str = None):
        """Split raw text into chunks and store them tagged as ``.txt``.

        Args:
            content: Text to split with ``DocumentLoader.split_text``.
            table_name: Target table; defaults to ``default_table_name``.
        """
        texts = list(self.document_loader.split_text(content))
        self._insert(texts, ".txt", table_name)

    @staticmethod
    def _extract_keywords(question: str) -> List[str]:
        """Return keywords found by jieba TF-IDF and TextRank.

        Each extractor contributes at most three noun/verb terms. The merged
        list keeps first-seen order, has no duplicates, and drops
        single-character terms.
        """
        tfidf_kws = jieba.analyse.extract_tags(
            question, topK=3, withWeight=False, allowPOS=('n', 'vn', 'v'))
        textrank_kws = jieba.analyse.textrank(
            question, topK=3, withWeight=False, allowPOS=('n', 'vn', 'v'))
        # dict.fromkeys de-duplicates while keeping a deterministic order.
        merged = dict.fromkeys(list(tfidf_kws) + list(textrank_kws))
        return [kw for kw in merged if len(kw) > 1]

    @staticmethod
    def _build_keyword_condition(keywords: List[str]) -> str:
        """Build an OR-joined SQL ``LIKE`` filter over the ``text`` column.

        Single quotes inside a keyword are doubled so the keyword cannot
        terminate the string literal. ``%`` and ``_`` inside a keyword are
        still treated as LIKE wildcards.

        Returns:
            The filter expression, or ``""`` when ``keywords`` is empty.
        """
        if not keywords:
            return ""
        conditions = []
        for kw in keywords:
            escaped = str(kw).replace("'", "''")
            conditions.append(f"text LIKE '%{escaped}%'")
        return " OR ".join(conditions)

    def _rewrite_question(self, question: str) -> dict:
        """Ask the chat model to rewrite ``question`` and parse its reply.

        Returns:
            Whatever ``parse_json_string`` produces from the model response.
            ``query`` reads the ``"rewrite"`` and ``"keywords"`` keys from it.
        """
        prompt = OnlinePrompt("rewrite_question").generate(question)
        response = self.llm.llm(prompt)
        return parse_json_string(response)

    @staticmethod
    def _ensure_fts_index(table) -> None:
        """Create the full-text index on ``text_fts`` if it is missing."""
        # NOTE(review): assumes list_indices() yields mappings with
        # "column_name" / "index_type" keys - confirm against the installed
        # lancedb version.
        index_exists = any(
            index["column_name"] == "text_fts" and index["index_type"] == "INVERTED"
            for index in table.list_indices()
        )
        if index_exists:
            return
        try:
            table.create_fts_index("text_fts")
        except ValueError as e:
            if "Index already exists" not in str(e):
                raise
            # The index exists but was not detected above; rebuild it.
            table.create_fts_index("text_fts", replace=True)

    @staticmethod
    def _merge_results(rows: list) -> list:
        """Drop duplicate hits and sort the rest newest-first.

        Two rows are duplicates when their ``text``, ``extension`` and
        ``time_stamp`` all match; the first occurrence is kept. The sort is
        stable, so rows with equal timestamps keep their search order.
        """
        seen = set()
        unique = []
        for row in rows:
            identity = (row.get("text"), row.get("extension"), row.get("time_stamp"))
            if identity in seen:
                continue
            seen.add(identity)
            unique.append(row)
        return sorted(unique, key=lambda row: -row['time_stamp'])

    def query(self, question: str, k: int = 10, table_name: str = None) -> list[dict]:
        """Retrieve stored chunks relevant to ``question``.

        The question is rewritten by the chat model. Keywords proposed by the
        model are kept only if jieba also extracted them. One hybrid search is
        run per surviving keyword; if none survive, a pure vector search is
        run instead so the call still returns results.

        Args:
            question: User question.
            k: Row limit for each individual search.
            table_name: Table to search; defaults to ``default_table_name``.

        Returns:
            De-duplicated result rows, newest ``time_stamp`` first.
        """
        extracted = set(self._extract_keywords(question))
        rewritten_data = self._rewrite_question(question)
        if not isinstance(rewritten_data, dict):
            # Unparseable model output: continue with the raw question.
            rewritten_data = {}
        proposed = rewritten_data.get("keywords") or []
        keywords = list(dict.fromkeys(kw for kw in proposed if kw in extracted))
        # An empty rewrite would embed "", so fall back to the original text.
        rewritten_question = rewritten_data.get("rewrite") or question

        question_embedding = self.embeddings.embed_query(rewritten_question)
        table_name = table_name or self.default_table_name
        table = self.db.open_table(table_name)
        self._ensure_fts_index(table)

        combined = []
        if keywords:
            for keyword in keywords:
                combined.extend(
                    table.search(query_type="hybrid")
                    .vector(question_embedding)
                    .text(keyword)
                    .select(self._RESULT_COLUMNS)
                    .limit(k)
                    .to_list()
                )
        else:
            combined.extend(
                table.search(question_embedding)
                .select(self._RESULT_COLUMNS)
                .limit(k)
                .to_list()
            )
        # Simple rerank: newest rows first.
        return self._merge_results(combined)
# Shared module-level pipeline, constructed with default settings on import.
rag_pipline = RAGPipeline()

if __name__ == "__main__":
    # Demo: insert a raw text snippet directly, then query against it.
    rag_pipline.insert_text_content("这是一个要嵌入的示例文本。")
    print(rag_pipline.query("示例文本是什么?"))