Files
chat-bot/extension/mcp.py
lychang 64ce30fdfd init
2025-08-26 09:35:29 +08:00

445 lines
15 KiB
Python

import json
from typing import Dict, List, Optional
from dataclasses import dataclass
from abc import ABC, abstractmethod
from urllib.parse import quote
import requests
from core.config import conf
from core.model import run_llm_by_message, think_by_message
from core.types import BaseCRUD, BaseEngine
from extension.standard import parse_json_string, OnlinePrompt, db_manager
from function.context import rag_search
from function.weather import weather_search
from function.web_tool import web_scraper
class ProtocolAdapter(ABC):
    """Transport abstraction that MCPService uses to call a service endpoint.

    Subclasses supply ``name`` and ``handle_request``. ``request`` is a
    convenience wrapper that reduces the response dict to a plain string.
    """

    name: str

    @property
    @abstractmethod
    def name(self):
        """Protocol identifier, e.g. ``'http'`` or ``'https'``."""

    @abstractmethod
    def handle_request(self, request: dict) -> dict:
        """Perform ``request``; return a dict with ``status`` and ``message`` keys."""

    def request(self, request: dict) -> str:
        """Run ``handle_request`` and return its message.

        Returns ``""`` when the reported status is 500 (the error marker used
        by the HTTP adapter).
        """
        outcome = self.handle_request(request)
        if outcome["status"] != 500:
            return outcome["message"]
        return ""
class HttpProtocolAdapter(ProtocolAdapter):
    """Adapter that calls a service endpoint over HTTP or HTTPS.

    GET requests are fetched through ``web_scraper.get_uri_resource`` with the
    parameters encoded into the query string; POST requests send the
    parameters as a JSON body via ``requests.post``.
    """

    def __init__(self, secure: bool = False, timeout: float = 30.0):
        """
        :param secure: use ``https`` instead of ``http``.
        :param timeout: seconds to wait for a POST response; without a timeout
            ``requests.post`` may block indefinitely on an unresponsive server.
        """
        self.secure = secure
        self.timeout = timeout
        self.headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/133.0.0.0 Safari/537.36 Edg/133.0.0.0",
            "Accept": "*/*",
            "Connection": "keep-alive"
        }

    @property
    def name(self):
        return 'https' if self.secure else 'http'

    @staticmethod
    def _with_query(uri: str, params) -> str:
        """Return ``uri`` with ``params`` (a mapping) appended as a query string.

        Both keys and values are percent-encoded. Nothing is appended for empty
        params, and ``&`` is used when ``uri`` already carries a query.
        """
        if not params:
            return uri
        query = '&'.join(f"{quote(str(k))}={quote(str(v))}" for k, v in params.items())
        separator = '&' if '?' in uri else '?'
        return f"{uri}{separator}{query}"

    def handle_request(self, request: dict) -> dict:
        """Execute ``request`` (keys: ``method``, ``uri``, ``parameters``).

        :returns: ``{"status": 200, "message": ...}`` on success, or
            ``{"status": 500, "message": <error text>}`` on any failure or
            unsupported method.
        """
        raw_method = request.get('method')
        # Accept "GET"/"Post" etc.; the method comes from user-editable service config.
        method = str(raw_method or '').lower()
        uri = f"{self.name}://{request['uri']}"
        params = request.get('parameters')
        try:
            if method == 'get':
                uri = self._with_query(uri, params)
                return {"message": web_scraper.get_uri_resource(uri), "status": 200}
            if method == 'post':
                data = requests.post(
                    uri, json=params, headers=self.headers, timeout=self.timeout
                ).json()
                return {'status': 200, 'message': str(data)}
            return {"status": 500, "message": f"Unsupported method: {raw_method}"}
        except Exception as e:
            return {"status": 500, "message": str(e)}
class LocalProtocolAdapter(ProtocolAdapter):
    """Adapter for the ``local`` protocol.

    ``handle_request`` ignores its argument and always reports success.
    """

    @property
    def name(self):
        return "local"

    def handle_request(self, request: dict) -> dict:
        """Acknowledge ``request`` without acting on it."""
        reply = dict(status="success", message="Local request handled successfully")
        return reply
@dataclass
class MCPService:
    """A registered MCP service: identity, endpoint, transport adapter and config.

    ``config`` keys read in this module: ``method`` (passed to the adapter),
    ``schema`` (name of a stored parameter schema) and ``description`` (read by
    the engine when listing tools). ``status`` is 1 after ``start_service`` and
    0 after ``stop_service``.
    """
    instance_name: str
    name: str
    endpoint: str
    id: int = -1
    protocol: Optional["ProtocolAdapter"] = None
    status: int = 0
    config: Optional[dict] = None

    def set_protocol(self, protocol: str):
        """Attach the adapter for ``protocol`` (``http``, ``https`` or ``local``).

        :raises ValueError: for any other protocol name.
        """
        secure = protocol.endswith('s')
        base_protocol = protocol[:-1] if secure else protocol
        if base_protocol == 'http':
            self.protocol = HttpProtocolAdapter(secure=secure)
        elif base_protocol == "local":
            self.protocol = LocalProtocolAdapter()
        else:
            raise ValueError(f"Unsupported protocol: {protocol}")

    def _summary(self, state: str) -> dict:
        """Describe this service for start/stop responses, tagged with ``state``."""
        return {
            'id': self.id,
            'name': self.name,
            'endpoint': self.endpoint,
            'protocol': self.protocol.__class__.__name__,
            'status': state,
        }

    def start_service(self) -> dict:
        """Mark the service running (in memory only) and return a status payload."""
        self.status = 1
        return {
            'status': 'success',
            'message': f'Service {self.id} started successfully',
            'service': self._summary('running'),
        }

    def stop_service(self) -> dict:
        """Mark the service stopped (in memory only) and return a status payload."""
        self.status = 0
        return {
            'status': 'success',
            'message': f'Service {self.id} stopped successfully',
            'service': self._summary('stopped'),
        }

    def execute(self, message: str) -> str:
        """Generate call parameters for ``message`` with the LLM and invoke the service.

        :returns: the adapter's response text, or ``"Execution error: ..."`` if
            any step fails (schema lookup, parameter generation or the call).
        """
        config = self.config or {}
        try:
            # Schema lookup and LLM parsing can fail too; keep them inside the
            # try so failures are reported like transport errors.
            schema = self._get_schema()
            params = self._generate_parameters(message, schema)
            return self.protocol.request({
                "method": config.get('method', ""),
                "parameters": params,
                "uri": self.endpoint
            })
        except Exception as e:
            return f"Execution error: {str(e)}"

    def _get_schema(self) -> str:
        """Return the ``content`` of the schema named in ``config['schema']``, or ``""``."""
        schema_name = (self.config or {}).get('schema', "")
        if not schema_name:
            return ""
        record = db_manager.schemas.get_by_name(schema_name)
        if not record:
            return ""
        # A missing/NULL content must not leak None into the prompt template.
        return dict(record).get('content') or ""

    @staticmethod
    def _generate_parameters(message: str, schema: str) -> dict:
        """Ask the LLM to produce request parameters for ``message`` following ``schema``."""
        prompt = OnlinePrompt("parameter_generate").generate(message)
        prompt = prompt.replace("{{schema}}", schema)
        return parse_json_string(run_llm_by_message(prompt))
class MCPManager:
    """Persisted registry of MCP services.

    Rows live in the ``services`` table accessed through ``BaseCRUD``;
    ``self.services`` mirrors them as ``MCPService`` objects keyed by the
    stringified row id. Mutating operations write the database first and only
    then update the in-memory copy, so a rejected write leaves both in sync.
    """

    # Protocol names accepted when registering or updating a service.
    SUPPORTED_PROTOCOLS = ('http', 'https', 'local')

    def __init__(self, db_path: str):
        """Create the table if needed and load every stored service.

        :param db_path: database location handed to ``BaseCRUD``.
        """
        self.db = BaseCRUD('services', db_path)
        self._init_db()
        self.services: Dict[str, MCPService] = {}
        self._sync_from_db()

    def _init_db(self):
        """Create the ``services`` table when it does not exist yet."""
        with self.db.get_connection() as conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS services (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    name TEXT UNIQUE NOT NULL,
                    instance_name TEXT NOT NULL,
                    instance_status INTEGER NOT NULL,
                    endpoint TEXT NOT NULL,
                    protocol TEXT NOT NULL,
                    config JSON,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            ''')

    def _sync_from_db(self):
        """Build an ``MCPService`` for every stored row into ``self.services``."""
        for row in self.db.get_all()['items']:
            raw_config = row['config']
            # The config column is nullable; treat NULL or a JSON null as {}.
            config = (json.loads(raw_config) if raw_config else None) or {}
            service = MCPService(
                id=row['id'],
                instance_name=row['instance_name'],
                name=row['name'],
                endpoint=row['endpoint'],
                status=row['instance_status'],
                config=config
            )
            service.set_protocol(row['protocol'])
            self.services[str(service.id)] = service

    @staticmethod
    def _not_found(service_id) -> dict:
        """Standard error payload for an unknown service id."""
        return {'status': 'error', 'message': f"Service {service_id} not found"}

    def list_services(self) -> List[dict]:
        """Return every stored service row as a plain dict."""
        return [dict(item) for item in self.db.get_all()['items']]

    def get_status(self) -> dict:
        """Return ``{"total_services": <number of stored rows>}``."""
        with self.db.get_connection() as conn:
            total_services = conn.execute('SELECT COUNT(*) FROM services').fetchone()[0]
        return {"total_services": total_services}

    def register_service(self, name: str, endpoint: str, protocol: str, config: dict = None) -> None:
        """Insert a new (stopped) service row and add it to the in-memory registry.

        :raises ValueError: if ``protocol`` is not in ``SUPPORTED_PROTOCOLS``.
        """
        if protocol not in self.SUPPORTED_PROTOCOLS:
            raise ValueError(f"Unsupported protocol: {protocol}")
        instance_name = f"mcp-{name.lower().replace(' ', '-')}"
        config = config if config else {}
        _id = self.db.create(
            name=name,
            instance_name=instance_name,
            instance_status=0,
            endpoint=endpoint,
            protocol=protocol,
            config=json.dumps(config)
        )
        service = MCPService(
            id=_id,
            instance_name=instance_name,
            name=name,
            endpoint=endpoint,
            config=config
        )
        service.set_protocol(protocol)
        self.services[str(_id)] = service

    def unregister_service(self, service_id: str) -> None:
        """Delete the service row and drop it from the in-memory registry."""
        self.db.delete(service_id)
        if service_id in self.services:
            del self.services[service_id]

    def create_service(self, name: str, endpoint: str, protocol: str, config: dict = None) -> dict:
        """``register_service`` wrapped into a success/error payload."""
        try:
            self.register_service(name, endpoint, protocol, config)
            return {'status': 'success', 'message': f'Service {name} created successfully'}
        except Exception as e:
            return {'status': 'error', 'message': str(e)}

    def delete_service(self, service_id: str) -> dict:
        """``unregister_service`` wrapped into a success/error payload."""
        try:
            self.unregister_service(service_id)
            return {'status': 'success', 'message': f'Service {service_id} deleted successfully'}
        except Exception as e:
            return {'status': 'error', 'message': str(e)}

    def start_service(self, service_id: str) -> dict:
        """Persist ``instance_status=1`` and mark the in-memory service running."""
        service = self.services.get(service_id)
        if service is None:
            return self._not_found(service_id)
        try:
            # Persist first so a failed write leaves the in-memory status untouched.
            self.db.update(service.id, instance_status=1)
            return service.start_service()
        except Exception as e:
            return {'status': 'error', 'message': str(e)}

    def stop_service(self, service_id: str) -> dict:
        """Persist ``instance_status=0`` and mark the in-memory service stopped."""
        service = self.services.get(service_id)
        if service is None:
            return self._not_found(service_id)
        try:
            self.db.update(service.id, instance_status=0)
            return service.stop_service()
        except Exception as e:
            return {'status': 'error', 'message': str(e)}

    def restart_service(self, service_id: str) -> dict:
        """Stop then start the service; a failed stop is returned without restarting."""
        if service_id not in self.services:
            return self._not_found(service_id)
        stopped = self.stop_service(service_id)
        if stopped.get('status') == 'error':
            return stopped
        return self.start_service(service_id)

    def check_health(self, service_id: str) -> dict:
        """Report the service as healthy when its in-memory status is 1."""
        service = self.services.get(service_id)
        if service is None:
            return self._not_found(service_id)
        is_healthy = service.status == 1
        return {
            'status': 'success',
            'healthy': is_healthy,
            'message': f'Service {service_id} is {"healthy" if is_healthy else "unhealthy"}'
        }

    def update_service(self, service_id: str, **kwargs) -> dict:
        """Update any of ``name``, ``endpoint``, ``protocol``, ``config``, ``status``.

        The database row is written first; the in-memory service is changed only
        if that write succeeds, so a rejected update (e.g. duplicate name) leaves
        both untouched.
        """
        service = self.services.get(service_id)
        if service is None:
            return self._not_found(service_id)
        if 'protocol' in kwargs and kwargs['protocol'] not in self.SUPPORTED_PROTOCOLS:
            return {'status': 'error', 'message': f"Unsupported protocol: {kwargs['protocol']}"}
        try:
            update_fields = {}
            if 'name' in kwargs:
                update_fields['name'] = kwargs['name']
            if 'endpoint' in kwargs:
                update_fields['endpoint'] = kwargs['endpoint']
            if 'protocol' in kwargs:
                update_fields['protocol'] = kwargs['protocol']
            if 'config' in kwargs:
                update_fields['config'] = json.dumps(kwargs['config'])
            if 'status' in kwargs:
                update_fields['instance_status'] = kwargs['status']
            if update_fields:
                self.db.update(service.id, **update_fields)

            # Mirror the accepted change into memory.
            if 'name' in kwargs:
                service.name = kwargs['name']
            if 'endpoint' in kwargs:
                service.endpoint = kwargs['endpoint']
            if 'protocol' in kwargs:
                service.set_protocol(kwargs['protocol'])
            if 'config' in kwargs:
                service.config = kwargs['config']
            if 'status' in kwargs:
                service.status = kwargs['status']
            return {
                'status': 'success',
                'message': f'Service {service_id} updated successfully',
                'service': {
                    'id': service.id,
                    'name': service.name,
                    'endpoint': service.endpoint,
                    'protocol': getattr(service.protocol, 'name', None),
                    'status': service.status,
                    'config': service.config
                }
            }
        except Exception as e:
            return {'status': 'error', 'message': str(e)}
# Module-wide service registry backed by the configured database; MCPEngine uses it.
mcp_manager = MCPManager(db_path=conf.db_uri)
class MCPPredictError(Exception):
    """Error type for MCP task-prediction failures (not raised in this module's visible code)."""
class MCPEngine(BaseEngine):
    """Route a user message to chat, context, local tools or running MCP services.

    ``_predict`` asks the LLM for a plan (a task dict or a list of them);
    ``run`` executes the tasks in order, feeding each tool's output into the
    prompts of the tasks that follow.
    """

    def __init__(self):
        super().__init__("mcp_engine")
        self._manager = mcp_manager
        self._context = None
        self._file = None
        # Tool/service name -> description offered to the planner.
        self.pool = {}
        # Service name -> stored row, for services with instance_status == 1.
        self.services = {}

    def set_context(self, context: str):
        """Store the text returned by the ``context`` pseudo-tool."""
        self._context = context

    def set_file(self, file: str):
        """Store the file name handed to local tools via ``set_file_name``."""
        self._file = file

    @staticmethod
    def _load_services_info() -> Dict[str, dict]:
        """Return running services (``instance_status == 1``) keyed by name."""
        services = mcp_manager.list_services()
        return {i['name']: i for i in services if i['instance_status'] == 1}

    def _predict(self, message: str) -> List[dict]:
        """Ask the LLM which tools to call for ``message``.

        :returns: the parsed plan (a task dict or list of task dicts), or ``[]``
            if the LLM response cannot be parsed.
        """
        self.services = self._load_services_info()
        # Planner menu, one "- [tool](优先级N) description" line per tool
        # ("优先级N" = "priority N"). chat: answer simple questions from the
        # context; context: must be used when earlier conversation is referenced.
        external_str = "- [chat](优先级2) 根据上下文回答一些简单的问题\n- [context](优先级5) 涉及前文,前面内容时,必须调用\n"
        for name, info in self.services.items():
            if info['protocol'] == 'local':
                tool = self.tool_pool.get(name)
                if tool:
                    self.pool[name] = tool.description
                    external_str += f"- [{name}](优先级4) {tool.description} \n"
            else:
                # The config column is nullable; NULL or a JSON null means no config.
                config = json.loads(info.get('config') or "{}") or {}
                desc = config.get('description', '')
                self.pool[name] = desc
                external_str += f"- [{name}](优先级3) {desc} \n"
        op = OnlinePrompt(self.name)
        op.set_external(external_str)
        prompt = op.generate(message)
        response = think_by_message(prompt)
        try:
            return parse_json_string(response)
        except Exception:
            # An unparseable plan yields no tasks; run() then returns its default error result.
            return []

    def _run_task(self, message: str, task: dict, external_data: str = "") -> dict:
        """Execute one planned task.

        :param message: the original user message (the main task).
        :param task: planner entry with ``tool`` and optionally ``question``/``answer``.
        :param external_data: outputs of earlier tools, prepended to the prompt.
        :returns: a dict with ``tool``, ``output`` and ``data`` keys.
        """
        if not isinstance(task, dict):
            return {"tool": "error", "output": "text", "data": f"Invalid task: {task!r}"}
        tool_name = task.get("tool", "")
        question = task.get("question", "")
        if not tool_name:
            return {"tool": "error", "output": "text", "data": f"Tool:{tool_name} execute error"}
        task_prompt = f"main_task: {message}\nsub_task: {question}"
        if external_data:
            task_prompt = f"{external_data}\n{task_prompt}"
        if tool_name == "chat":
            return {"tool": "answer", "output": "text", "data": task.get("answer", "")}
        if tool_name == "context":
            return {"tool": "context", "output": "text", "data": self._context}
        if tool_name in self.tool_pool:
            tool = self.tool_pool[tool_name]
            tool.set_file_name(self._file)
            return tool.execute(task_prompt)
        if tool_name in self.services:
            service = self._manager.services.get(str(self.services[tool_name]['id']))
            if service is None:
                # Listed as running in the database but absent from the manager's in-memory registry.
                return {"tool": "error", "output": "text", "data": f"Tool:{tool_name} execute error"}
            return {"tool": tool_name, "output": "text", "data": service.execute(task_prompt)}
        # Unknown tool: fall back to echoing the message as a chat result.
        return {"tool": "chat", "output": "text", "data": message}

    def run(self, message: str, file_name: str = None, plugin_type: str = None) -> dict:
        """Plan and execute tasks for ``message``; return the last task's result.

        ``file_name`` and ``plugin_type`` are accepted for interface
        compatibility and are not used here.
        """
        import logging

        tasks = self._predict(message)
        if isinstance(tasks, dict):
            tasks = [tasks]
        elif not isinstance(tasks, list):
            # Guard against plans that are neither a task nor a list of tasks.
            tasks = []
        logging.getLogger(__name__).debug("MCP tasks: %s", tasks)
        external_data = ""
        result = {"tool": "error", "output": "text", "data": "MCP engine execute error"}
        for task in tasks:
            result = self._run_task(message, task, external_data)
            data = result.get('data')
            # Feed real tool output to later tasks, newline-separated so outputs
            # stay distinct; skip empty/None data (e.g. context before set_context).
            if result.get('tool') not in ("chat", "error") and data:
                external_data += f"{data}\n"
        return result
mcp_engine = MCPEngine()
# Register the built-in tools with the shared engine, in this order.
for _builtin_tool in (weather_search, web_scraper, rag_search):
    mcp_engine.add_tool(_builtin_tool)
del _builtin_tool