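"""Chat orchestration module.

Defines ChatEngine, which combines the RAG pipeline, the MCP tool engine,
and LLM helpers to drive multi-turn chat with file support, and exposes a
module-level ModelManager singleton as the public entry point.
"""
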
from core.model import run_llm_by_message
from core.types import BaseEngine
from extension.mcp import mcp_engine
from extension.rag import rag_pipline
from extension.standard import chat_file_manager, run_llm_by_template


class ChatEngine(BaseEngine):
    def __init__(self):
        super().__init__("chat_engine")
        self.dialogue = []  # full conversation: {"role": ..., "message": ...} dicts
        self.file = {}      # tracked file ids, keyed by "uploader" / "sender"
        self._rag = rag_pipline
        self._body = None

    @property
    def context(self):
        # Render the 16 most recent messages as a "role: message" transcript
        # separated by "##" markers; slicing handles shorter histories too.
        contexts = self.dialogue[-16:]
        return "\n##\n\n".join([msg["role"] + ": " + msg["message"] for msg in contexts])

    def set_body(self, body):
        self._body = body

    def load_chat_history(self, chat_history):
        # Rebuild engine state from a persisted history: conversational roles
        # go back into the dialogue, while uploader/sender entries only
        # restore the tracked file ids.
        self.dialogue = []
        self.file.clear()
        for msg in chat_history:
            if msg["role"] in ["user", "system", "external"]:
                self.dialogue.append(msg)
            elif msg["role"] == "uploader":
                file_id, _ = chat_file_manager.parse_file_id(msg["message"])
                self.file["uploader"] = file_id
            elif msg["role"] == "sender":
                self.file["sender"] = msg["message"]

    def add_message(self, message: str, role: str):
        self.dialogue.append({"role": role, "message": message})

    def add_user_message(self, message: str):
        self.add_message(message, "user")

    def add_system_message(self, message: str):
        self.add_message(message, "system")

    def load_file(self, message: str):
        # Index an uploaded file into the RAG store and remember its id.
        try:
            file_id, ext = chat_file_manager.parse_file_id(message)
            self._rag.insert_document(file_id, file_id, ext)
            self.file["uploader"] = file_id
            return {"message": "File uploaded successfully! Please continue with your questions.", "role": "system"}
        except Exception:
            return {"message": "File upload failed! Please contact the administrator.", "role": "system"}

    def run(self, message: str, file_name=None, plugin_type=None) -> dict:
        current_context = self.context
        self.add_user_message(message)
        mcp_engine.set_context(current_context)
        mcp_engine.set_file(self.file.get("uploader"))

        try:
            mcp_data = mcp_engine.run(message)
            action = mcp_data["tool"]
            if action == "answer":
                # Direct answer from the tool engine; fall back to plain
                # generation if it returned no payload.
                data = mcp_data.get("data", "")
                if not data:
                    data = self._generate_chat(message, current_context, "")["data"]
                return self._chat(data)

            # Any other tool: feed its output into generation as external context.
            print(mcp_data)
            external_data = mcp_data.get("data", "")
            generate_data = self._generate_chat(message, current_context, external_data)
            return self._data_parse(generate_data)

        except Exception as e:
            return {"message": f"Processing error: {str(e)}", "role": "system"}

    def _chat(self, answer: str):
        self.add_system_message(answer)
        return {"message": answer, "role": "system"}

    @staticmethod
    def _generate_chat(message: str, memory: str, external: str):
        # Assemble the prompt from tagged sections: conversation memory,
        # optional external tool output, and the user input.
        context = ""
        if memory:
            context += f"<content>{memory}</content>\n"
        if external:
            context += f"<external>{external}</external>\n"
        context += f"<input>{message}</input>"
        result = run_llm_by_template(context, "text_generate")
        print(result)
        return {"type": "text", "data": result}

    def _data_parse(self, data: dict) -> dict:
        if data["type"] == "text":
            self.add_system_message(data["data"])
            return {"message": data["data"], "role": "system"}
        elif data["type"] == "file":
            self.file["sender"] = data["data"]
            return {"message": data["data"], "role": "sender"}
        else:
            raise ValueError("Unsupported data type")


class ModelManager:
    def __init__(self):
        self.allow_list = []

    def chat(self, message: str, message_type: str = "text", chat_history=None):
        # A fresh engine per call; state is reconstructed from chat_history.
        ce = ChatEngine()
        ce.set_allow_list(self.allow_list)  # provided by BaseEngine
        ce.load_chat_history(chat_history if chat_history else [])
        if message_type == "text":
            return ce.run(message)
        elif message_type == "file":
            return ce.load_file(message)
        else:
            raise ValueError("Message type not supported!")

    @staticmethod
    def generate(prompt):
        return run_llm_by_message(prompt)


model_manager = ModelManager()
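
# Minimal usage sketch (illustrative only): the prompts and history entries
# below are hypothetical, and the backing services (LLM, RAG, MCP engines)
# must be configured before these calls will succeed.
if __name__ == "__main__":
    # A plain text turn, routed through ChatEngine.run().
    print(model_manager.chat("What does this document cover?"))

    # Resuming a prior conversation: history entries use the same
    # {"role": ..., "message": ...} shape the engine stores internally.
    history = [
        {"role": "user", "message": "Hello"},
        {"role": "system", "message": "Hi! How can I help?"},
    ]
    print(model_manager.chat("Summarize our chat so far.", chat_history=history))

    # message_type="file" routes to ChatEngine.load_file(); the payload must be
    # in whatever format chat_file_manager.parse_file_id() accepts.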