143 lines
5.6 KiB
Python
143 lines
5.6 KiB
Python
import time
|
|
from typing import List
|
|
|
|
import jieba
|
|
import jieba.analyse
|
|
from lancedb import connect
|
|
from core.model import get_embedding_model, get_chat_model
|
|
from extension.document_loader import DocumentLoader
|
|
from extension.standard import OnlinePrompt, parse_json_string
|
|
|
|
|
|
class RAGPipeline:
    """Retrieval-augmented generation pipeline backed by a LanceDB vector store.

    Documents are chunked, embedded, and stored together with a
    whitespace-free copy of the text (``text_fts``) used for full-text search.
    Queries combine LLM-based question rewriting, jieba keyword extraction,
    and hybrid vector+FTS search, re-ranked newest-first by insert timestamp.
    """

    def __init__(
        self,
        db_path: str = "./db/vec",
        default_table_name: str = "documents",
        embedding_model: str = "bge-large",
        llm_model: str = "qwen2.5:7b",
        chunk_size: int = 512,
        chunk_overlap: int = 200,
    ):
        """Connect to LanceDB and set up the embedding model, LLM and loader.

        Args:
            db_path: Filesystem path of the LanceDB database.
            default_table_name: Table used when callers pass no table name.
            embedding_model: Name resolved by ``get_embedding_model``.
            llm_model: Name resolved by ``get_chat_model``.
            chunk_size: Target chunk length for document splitting.
            chunk_overlap: Overlap between consecutive chunks.
        """
        self.embeddings = get_embedding_model(embedding_model)
        self.llm = get_chat_model(llm_model)
        self.db = connect(db_path)
        self.default_table_name = default_table_name
        self.document_loader = DocumentLoader(
            config={"chunk_size": chunk_size, "chunk_overlap": chunk_overlap}
        )

    def _insert(self, texts: list, extension: str, table_name: str = None):
        """Embed *texts* and append them to *table_name* (created on demand).

        Each row stores the cleaned text, a space-free copy for the FTS
        index, the embedding vector, the source extension, and an insert
        timestamp later used for recency re-ranking in :meth:`query`.
        """
        if not texts:
            # Nothing to embed: avoids an empty embedding batch and avoids
            # creating an empty table with no inferable schema.
            return
        table_name = table_name or self.default_table_name
        embeddings = self.embeddings.embed_documents(texts)
        time_stamp = int(time.time())
        data = []
        for text, embedding in zip(texts, embeddings):
            # Strip line breaks/tabs once and reuse for both columns instead
            # of repeating the replace chain per column.
            cleaned = text.replace("\n", "").replace("\t", "").replace("\r", "")
            data.append({
                "text": cleaned,
                "text_fts": cleaned.replace(" ", ""),
                "vector": embedding,
                "extension": extension,
                "time_stamp": time_stamp,
            })

        if table_name in self.db.table_names():
            table = self.db.open_table(table_name)
            table.add(data)
        else:
            # Create the table with the schema inferred from the first batch.
            self.db.create_table(
                table_name,
                data=data,
                mode="overwrite",
            )

    def insert_document(self, file_path: str, table_name: str = None, extension: str = None):
        """Load *file_path*, split it into chunks, and insert the chunk texts."""
        self.document_loader.load_content(file_path, extension)
        chunks = self.document_loader.load_and_split()
        texts = [chunk.page_content for chunk in chunks]
        self._insert(texts, self.document_loader.extension, table_name)

    def insert_text_content(self, content: str, table_name: str = None):
        """Split raw *content* and insert the chunks as plain text (``.txt``)."""
        texts = list(self.document_loader.split_text(content))
        self._insert(texts, ".txt", table_name)

    @staticmethod
    def _extract_keywords(question: str) -> List[str]:
        """Extract salient keywords from *question* using jieba.

        Combines TF-IDF and TextRank rankings (top 3 each), deduplicates,
        and drops single-character tokens, which are too noisy for matching.
        """
        tfidf_kws = jieba.analyse.extract_tags(
            question, topK=3, withWeight=False, allowPOS=('n', 'vn', 'v'))
        textrank_kws = jieba.analyse.textrank(
            question, topK=3, withWeight=False, allowPOS=('n', 'vn', 'v'))

        combined_kws = list(set(tfidf_kws + textrank_kws))
        return [kw for kw in combined_kws if len(kw) > 1]  # filter short keywords

    @staticmethod
    def _build_keyword_condition(keywords: List[str]) -> str:
        """Build a SQL ``WHERE`` fragment matching any of *keywords*.

        Single quotes are doubled so a keyword cannot terminate the quoted
        LIKE literal (basic SQL-injection hardening — keywords ultimately
        originate from user questions).
        """
        if not keywords:
            return ""
        conditions = [
            "text LIKE '%{}%'".format(kw.replace("'", "''")) for kw in keywords
        ]
        return " OR ".join(conditions)

    def _rewrite_question(self, question: str) -> dict:
        """Ask the LLM to rewrite *question*; return the parsed JSON dict."""
        op = OnlinePrompt("rewrite_question")
        prompt = op.generate(question)
        response = self.llm.llm(prompt)
        return parse_json_string(response)

    def query(self, question: str, k: int = 10, table_name: str = None) -> list[dict]:
        """Retrieve up to *k* chunks per keyword relevant to *question*.

        The question is rewritten by the LLM; jieba keywords that the LLM
        also proposed drive per-keyword hybrid (vector+FTS) searches. When no
        keywords survive, a plain vector search is used as a fallback so the
        query still returns results. Rows are deduplicated by text and
        sorted newest-first.
        """
        keywords = self._extract_keywords(question)
        rewritten_data = self._rewrite_question(question)
        # Keep only LLM keywords that jieba also found. Loop variable renamed
        # from `k`, which shadowed the result-limit parameter.
        keywords = [kw for kw in rewritten_data.get("keywords", []) if kw in keywords]
        # Fall back to the original question when the LLM produced no rewrite,
        # so we never embed an empty string.
        rewritten_question = rewritten_data.get("rewrite", "") or question

        question_embedding = self.embeddings.embed_query(rewritten_question)
        table_name = table_name or self.default_table_name
        table = self.db.open_table(table_name)

        self._ensure_fts_index(table)

        combined = []
        if keywords:
            for keyword in keywords:
                combined += (table.search(query_type="hybrid")
                             .vector(question_embedding)
                             .text(keyword)
                             .select(["text", "extension", "time_stamp"])
                             .limit(k)
                             .to_list())
        else:
            # No usable keywords: pure vector search instead of silently
            # returning an empty result set.
            combined = (table.search(question_embedding)
                        .select(["text", "extension", "time_stamp"])
                        .limit(k)
                        .to_list())

        # Deduplicate rows returned by multiple keyword searches, preserving
        # first occurrence, then re-rank (simple time-weighted sort).
        seen = set()
        unique = []
        for row in combined:
            if row["text"] not in seen:
                seen.add(row["text"])
                unique.append(row)
        return sorted(unique, key=lambda x: -x['time_stamp'])

    @staticmethod
    def _ensure_fts_index(table) -> None:
        """Create the INVERTED FTS index on ``text_fts`` if it is missing."""
        indices = table.list_indices()
        index_exists = any(
            index["column_name"] == "text_fts" and index["index_type"] == "INVERTED"
            for index in indices
        )
        if index_exists:
            return
        try:
            table.create_fts_index("text_fts")
        except ValueError as e:
            if "Index already exists" in str(e):
                # Index exists but was not reported by list_indices; replace it.
                table.create_fts_index("text_fts", replace=True)
            else:
                # Bare raise preserves the original traceback.
                raise
|
|
|
|
|
|
# Module-level singleton. NOTE(review): instantiating here connects to the
# LanceDB database and loads the embedding/chat models as an import-time side
# effect; the "pipline" spelling is kept because importers may depend on it.
rag_pipline = RAGPipeline()

if __name__ == "__main__":
    # Example of inserting raw text content directly.
    rag_pipline.insert_text_content("这是一个要嵌入的示例文本。")
    result = rag_pipline.query("示例文本是什么?")
    print(result)
|