445 lines
15 KiB
Python
445 lines
15 KiB
Python
import json
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Optional
from urllib.parse import quote, urlencode

import requests

from core.config import conf
from core.model import run_llm_by_message, think_by_message
from core.types import BaseCRUD, BaseEngine
from extension.standard import parse_json_string, OnlinePrompt, db_manager
from function.context import rag_search
from function.weather import weather_search
from function.web_tool import web_scraper
|
|
|
|
|
|
class ProtocolAdapter(ABC):
    """Abstract transport an MCPService uses to reach its endpoint.

    Concrete adapters provide a ``name`` and implement ``handle_request``,
    which must return a dict containing a ``status`` key and, on success,
    a ``message`` key.
    """

    name: str

    @property
    @abstractmethod
    def name(self):
        """Short identifier of the transport."""

    @abstractmethod
    def handle_request(self, request: dict) -> dict:
        """Perform *request* and return a ``{"status": ..., "message": ...}`` dict."""

    def request(self, request: dict) -> str:
        """Run *request* and return just its message text.

        A result whose ``status`` equals 500 is treated as a failure and is
        collapsed to an empty string (``message`` is not read in that case);
        any other status yields ``message``.
        """
        outcome = self.handle_request(request)
        return "" if outcome["status"] == 500 else outcome["message"]
|
|
|
|
|
|
class HttpProtocolAdapter(ProtocolAdapter):
    """Transport that reaches a service endpoint over HTTP or HTTPS.

    ``get`` requests are fetched with ``web_scraper.get_uri_resource`` after the
    parameters are encoded into the query string. ``post`` requests send the
    parameters as a JSON body via ``requests`` and return the decoded reply
    stringified. The method name is matched case-insensitively. Every failure
    is reported as ``{"status": 500, "message": <error text>}``.
    """

    def __init__(self, secure: bool = False, timeout: float = 30):
        """
        :param secure: use ``https`` rather than ``http`` as the URI scheme.
        :param timeout: seconds to wait for a POST response; without a timeout
            ``requests`` may block indefinitely on an unresponsive server.
        """
        self.secure = secure
        self.timeout = timeout
        # Browser-like headers; only the POST path sends them (GET goes through web_scraper).
        self.headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/133.0.0.0 Safari/537.36 Edg/133.0.0.0",
            "Accept": "*/*",
            "Connection": "keep-alive"
        }

    @property
    def name(self):
        """URI scheme for this adapter: ``https`` when secure, otherwise ``http``."""
        return 'https' if self.secure else 'http'

    @staticmethod
    def _build_query_uri(uri: str, params: Optional[dict]) -> str:
        """Return *uri* with *params* appended as a percent-encoded query string.

        Keys and values are both encoded (non-string values via ``str()``), and
        ``/`` is left unescaped. No ``?`` is added when there are no params, and
        a URI that already has a query string is extended with ``&``.
        """
        if not params:
            return uri
        query = urlencode(params, safe='/', quote_via=quote)
        separator = '&' if '?' in uri else '?'
        return f"{uri}{separator}{query}"

    def handle_request(self, request: dict) -> dict:
        """Send *request* and wrap the outcome in a status/message dict.

        *request* carries ``method``, ``uri`` (without scheme; this adapter's
        ``name`` is prepended) and optionally ``parameters``.
        """
        method = request['method']
        verb = str(method).lower()
        uri = f"{self.name}://{request['uri']}"
        params = request.get('parameters')
        try:
            if verb == 'get':
                target = self._build_query_uri(uri, params)
                return {"message": web_scraper.get_uri_resource(target), "status": 200}
            if verb == 'post':
                response = requests.post(uri, json=params, headers=self.headers, timeout=self.timeout)
                return {'status': 200, 'message': str(response.json())}
            return {"status": 500, "message": f"Unsupported method: {method}"}
        except Exception as e:
            # Boundary of the transport: any failure becomes a 500 result for ProtocolAdapter.request.
            return {"status": 500, "message": str(e)}
|
|
|
|
|
|
class LocalProtocolAdapter(ProtocolAdapter):
    """Placeholder transport for in-process services.

    The request is not inspected; every call reports the same canned
    success message.
    """

    @property
    def name(self):
        """Identifier of this transport."""
        return "local"

    def handle_request(self, request: dict) -> dict:
        # NOTE(review): status is the string "success" here, while the HTTP
        # adapter uses integer codes; ProtocolAdapter.request only checks for 500.
        canned = "Local request handled successfully"
        return {"status": "success", "message": canned}
|
|
|
|
|
|
@dataclass
class MCPService:
    """In-memory view of one registered service, plus the logic to invoke it.

    ``config`` holds the service's decoded JSON config. The keys read here are
    ``method`` (passed to the adapter as the request method) and ``schema``
    (name of a schema record looked up through ``db_manager``).
    """

    instance_name: str
    name: str
    endpoint: str
    id: int = -1
    protocol: Optional[ProtocolAdapter] = None
    status: int = 0  # start_service sets 1, stop_service sets 0
    config: Optional[dict] = None

    def __post_init__(self):
        # Normalize a missing config so the ``self.config.get`` calls below never hit None.
        if self.config is None:
            self.config = {}

    def set_protocol(self, protocol: str):
        """Attach the adapter matching *protocol* (e.g. ``http``, ``https``, ``local``).

        A trailing ``s`` selects the secure variant of the base protocol.

        :raises ValueError: if the base protocol is not supported.
        """
        secure = protocol.endswith('s')
        base_protocol = protocol[:-1] if secure else protocol

        if base_protocol == 'http':
            self.protocol = HttpProtocolAdapter(secure=secure)
        elif base_protocol == "local":
            self.protocol = LocalProtocolAdapter()
        else:
            raise ValueError(f"Unsupported protocol: {protocol}")

    def _summary(self, state: str) -> dict:
        """Describe this service for start/stop responses, reporting *state* as its status."""
        return {
            'id': self.id,
            'name': self.name,
            'endpoint': self.endpoint,
            'protocol': self.protocol.__class__.__name__,
            'status': state
        }

    def start_service(self) -> dict:
        """Mark the service as running (status 1) and return a success payload."""
        self.status = 1
        return {
            'status': 'success',
            'message': f'Service {self.id} started successfully',
            'service': self._summary('running')
        }

    def stop_service(self) -> dict:
        """Mark the service as stopped (status 0) and return a success payload."""
        self.status = 0
        return {
            'status': 'success',
            'message': f'Service {self.id} stopped successfully',
            'service': self._summary('stopped')
        }

    def execute(self, message: str) -> str:
        """Turn *message* into request parameters via the LLM and send them.

        Returns the adapter's message text (empty for a status-500 result).
        Any exception raised while resolving the schema, generating the
        parameters or performing the request is reported as
        ``"Execution error: ..."`` instead of propagating.
        """
        try:
            schema = self._get_schema()
            params = self._generate_parameters(message, schema)
            return self.protocol.request({
                "method": self.config.get('method', ""),
                "parameters": params,
                "uri": self.endpoint
            })
        except Exception as e:
            return f"Execution error: {str(e)}"

    def _get_schema(self) -> str:
        """Return the content of the schema named in ``config['schema']``, or ``""``."""
        schema_name = self.config.get('schema', "")
        if not schema_name:
            return ""
        record = db_manager.schemas.get_by_name(schema_name)
        if not record:
            return ""
        # A record without content would otherwise yield None and break str.replace later.
        return dict(record).get('content') or ""

    @staticmethod
    def _generate_parameters(message: str, schema: str) -> dict:
        """Ask the LLM for request parameters, substituting *schema* into the prompt's ``{{schema}}``."""
        prompt = OnlinePrompt("parameter_generate").generate(message)
        prompt = prompt.replace("{{schema}}", schema)
        return parse_json_string(run_llm_by_message(prompt))
|
|
|
|
|
|
class MCPManager:
    """Registry of MCP services persisted in the ``services`` table.

    Rows of the table are mirrored in ``self.services`` as ``MCPService``
    objects keyed by ``str(id)``. The ``*_service`` methods that return dicts
    report failures as ``{'status': 'error', 'message': ...}``.
    """

    # Protocols accepted by register_service and update_service.
    SUPPORTED_PROTOCOLS = ('http', 'https', 'local')

    # Keyword arguments update_service understands, mapped to their table column.
    _UPDATABLE_COLUMNS = {
        'name': 'name',
        'endpoint': 'endpoint',
        'protocol': 'protocol',
        'config': 'config',
        'status': 'instance_status',
    }

    def __init__(self, db_path: str):
        self.db = BaseCRUD('services', db_path)
        self._init_db()
        self.services: Dict[str, MCPService] = {}
        self._sync_from_db()

    def _init_db(self):
        """Create the ``services`` table if it does not exist yet."""
        with self.db.get_connection() as conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS services (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    name TEXT UNIQUE NOT NULL,
                    instance_name TEXT NOT NULL,
                    instance_status INTEGER NOT NULL,
                    endpoint TEXT NOT NULL,
                    protocol TEXT NOT NULL,
                    config JSON,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            ''')

    def _sync_from_db(self):
        """Load every row of the ``services`` table into ``self.services``."""
        for row in self.db.get_all()['items']:
            service = MCPService(
                id=row['id'],
                instance_name=row['instance_name'],
                name=row['name'],
                endpoint=row['endpoint'],
                status=row['instance_status'],
                # ``config`` is a nullable column; treat NULL as an empty config.
                config=json.loads(row['config'] or '{}')
            )
            service.set_protocol(row['protocol'])
            self.services[str(service.id)] = service

    @staticmethod
    def _not_found(service_id: str) -> dict:
        """Error payload for an id that is not in the registry."""
        return {'status': 'error', 'message': f"Service {service_id} not found"}

    def list_services(self) -> List[dict]:
        """Return every stored service row as a plain dict."""
        return [dict(item) for item in self.db.get_all()['items']]

    def get_status(self) -> dict:
        """Return the number of services stored in the table."""
        with self.db.get_connection() as conn:
            total_services = conn.execute('SELECT COUNT(*) FROM services').fetchone()[0]
        return {"total_services": total_services}

    def register_service(self, name: str, endpoint: str, protocol: str, config: dict = None) -> None:
        """Persist a new (stopped) service and add it to the registry.

        :raises ValueError: if *protocol* is not one of ``SUPPORTED_PROTOCOLS``.
        """
        if protocol not in self.SUPPORTED_PROTOCOLS:
            raise ValueError(f"Unsupported protocol: {protocol}")

        instance_name = f"mcp-{name.lower().replace(' ', '-')}"
        config = config if config else {}

        _id = self.db.create(
            name=name,
            instance_name=instance_name,
            instance_status=0,
            endpoint=endpoint,
            protocol=protocol,
            config=json.dumps(config)
        )

        service = MCPService(
            id=_id,
            instance_name=instance_name,
            name=name,
            endpoint=endpoint,
            config=config
        )
        service.set_protocol(protocol)
        self.services[str(_id)] = service

    def unregister_service(self, service_id: str) -> None:
        """Delete the service row and drop it from the registry if present."""
        self.db.delete(service_id)
        self.services.pop(service_id, None)

    def create_service(self, name: str, endpoint: str, protocol: str, config: dict = None) -> dict:
        """Wrapper around ``register_service`` that reports errors as a dict."""
        try:
            self.register_service(name, endpoint, protocol, config)
            return {'status': 'success', 'message': f'Service {name} created successfully'}
        except Exception as e:
            return {'status': 'error', 'message': str(e)}

    def delete_service(self, service_id: str) -> dict:
        """Wrapper around ``unregister_service`` that reports errors as a dict."""
        try:
            self.unregister_service(service_id)
            return {'status': 'success', 'message': f'Service {service_id} deleted successfully'}
        except Exception as e:
            return {'status': 'error', 'message': str(e)}

    def start_service(self, service_id: str) -> dict:
        """Mark a service as running in the database, then in memory."""
        service = self.services.get(service_id)
        if service is None:
            return self._not_found(service_id)
        try:
            # Persist first so a failed write leaves memory and DB in agreement.
            self.db.update(service.id, instance_status=1)
            return service.start_service()
        except Exception as e:
            return {'status': 'error', 'message': str(e)}

    def stop_service(self, service_id: str) -> dict:
        """Mark a service as stopped in the database, then in memory."""
        service = self.services.get(service_id)
        if service is None:
            return self._not_found(service_id)
        try:
            # Persist first so a failed write leaves memory and DB in agreement.
            self.db.update(service.id, instance_status=0)
            return service.stop_service()
        except Exception as e:
            return {'status': 'error', 'message': str(e)}

    def restart_service(self, service_id: str) -> dict:
        """Stop then start a service; returns the start result."""
        if service_id not in self.services:
            return self._not_found(service_id)
        try:
            self.stop_service(service_id)
            return self.start_service(service_id)
        except Exception as e:
            return {'status': 'error', 'message': str(e)}

    def check_health(self, service_id: str) -> dict:
        """Report a service as healthy when its in-memory status is 1."""
        service = self.services.get(service_id)
        if service is None:
            return self._not_found(service_id)
        is_healthy = service.status == 1
        return {
            'status': 'success',
            'healthy': is_healthy,
            'message': f'Service {service_id} is {"healthy" if is_healthy else "unhealthy"}'
        }

    def update_service(self, service_id: str, **kwargs) -> dict:
        """Update any of ``name``, ``endpoint``, ``protocol``, ``config``, ``status``.

        Changes are validated and written to the database first. The in-memory
        service is modified only after the write succeeds, so a rejected
        update (bad protocol, unserializable config, DB constraint) leaves
        both views unchanged. Unknown keyword arguments are ignored.
        """
        service = self.services.get(service_id)
        if service is None:
            return self._not_found(service_id)

        try:
            if 'protocol' in kwargs and kwargs['protocol'] not in self.SUPPORTED_PROTOCOLS:
                raise ValueError(f"Unsupported protocol: {kwargs['protocol']}")

            changes = {key: kwargs[key] for key in self._UPDATABLE_COLUMNS if key in kwargs}
            update_fields = {
                self._UPDATABLE_COLUMNS[key]: json.dumps(value) if key == 'config' else value
                for key, value in changes.items()
            }
            if update_fields:
                self.db.update(service.id, **update_fields)

            for key, value in changes.items():
                if key == 'protocol':
                    service.set_protocol(value)
                else:
                    setattr(service, key, value)

            return {
                'status': 'success',
                'message': f'Service {service_id} updated successfully',
                'service': {
                    'id': service.id,
                    'name': service.name,
                    'endpoint': service.endpoint,
                    'protocol': service.protocol.name,
                    'status': service.status,
                    'config': service.config
                }
            }
        except Exception as e:
            return {'status': 'error', 'message': str(e)}
|
|
|
|
|
|
# Shared service registry backed by the configured database; MCPEngine uses this instance.
mcp_manager = MCPManager(db_path=conf.db_uri)
|
|
|
|
|
|
class MCPPredictError(Exception):
    """Error type for MCP task-plan prediction failures.

    NOTE(review): nothing in this module raises it at present.
    """
|
|
|
|
|
|
class MCPEngine(BaseEngine):
    """Engine that plans tool calls with the LLM and runs them in sequence.

    Two pseudo-tools are always offered: ``chat`` (answer simple questions
    from context) and ``context`` (return the stored context). Every running
    service from the manager is offered too. ``local`` services map to tools
    in ``self.tool_pool``; other services are called through
    ``MCPService.execute``.
    """

    def __init__(self):
        super().__init__("mcp_engine")
        self._manager = mcp_manager
        self._context = None
        self._file = None
        self.pool = {}
        self.services = {}

    def set_context(self, context: str):
        """Store the text returned by the ``context`` pseudo-tool."""
        self._context = context

    def set_file(self, file: str):
        """Store the file name handed to tools before they execute."""
        self._file = file

    @staticmethod
    def _load_services_info() -> Dict[str, dict]:
        """Return stored service rows with ``instance_status == 1``, keyed by name."""
        services = mcp_manager.list_services()
        return {i['name']: i for i in services if i['instance_status'] == 1}

    def _describe_services(self) -> str:
        """Render running services as prompt lines and record their descriptions in ``self.pool``."""
        lines = []
        for name, info in self.services.items():
            if info['protocol'] == 'local':
                tool = self.tool_pool.get(name)
                if tool:
                    self.pool[name] = tool.description
                    lines.append(f"- [{name}](优先级4) {tool.description} \n")
            else:
                # ``config`` may be NULL (None) in the table; .get's default would not apply then.
                config = json.loads(info.get('config') or "{}")
                desc = config.get('description', '')
                self.pool[name] = desc
                lines.append(f"- [{name}](优先级3) {desc} \n")
        return "".join(lines)

    def _predict(self, message: str) -> List[dict]:
        """Ask the LLM for a task plan for *message*.

        The prompt lists the built-in pseudo-tools (``优先级`` = priority) and
        the running services. Returns whatever ``parse_json_string`` yields
        (normally a list of task dicts), or ``[]`` if the reply cannot be parsed.
        """
        self.services = self._load_services_info()
        external_str = "- [chat](优先级2) 根据上下文回答一些简单的问题\n- [context](优先级5) 涉及前文,前面内容时,必须调用\n"
        external_str += self._describe_services()

        op = OnlinePrompt(self.name)
        op.set_external(external_str)
        prompt = op.generate(message)
        response = think_by_message(prompt)
        try:
            return parse_json_string(response)
        except Exception:
            # An unparseable plan means no tasks; run() then returns its default error result.
            return []

    def _run_task(self, message: str, task: dict, external_data: str = "") -> dict:
        """Execute one planned *task* and return a ``{"tool", "output", "data"}`` dict.

        *external_data* (outputs of earlier tasks) is prepended to the prompt
        given to tools and services.
        """
        tool_name = task.get("tool", "")
        question = task.get("question", "")

        if not tool_name:
            return {"tool": "error", "output": "text", "data": f"Tool:{tool_name} execute error"}

        task_prompt = f"main_task: {message}\nsub_task: {question}"
        if external_data:
            task_prompt = f"{external_data}\n{task_prompt}"

        if tool_name == "chat":
            return {"tool": "answer", "output": "text", "data": task.get("answer", "")}
        elif tool_name == "context":
            return {"tool": "context", "output": "text", "data": self._context}

        if tool_name in self.tool_pool:
            tool = self.tool_pool[tool_name]
            tool.set_file_name(self._file)
            return tool.execute(task_prompt)

        if tool_name in self.services:
            service = self._manager.services[str(self.services[tool_name]['id'])]
            return {"tool": tool_name, "output": "text", "data": service.execute(task_prompt)}

        # Unknown tool: fall back to echoing the original message as chat.
        return {"tool": "chat", "output": "text", "data": message}

    def run(self, message: str, file_name: str = None, plugin_type: str = None) -> dict:
        """Plan and execute tasks for *message*, returning the last task's result.

        Outputs of non-``chat``/``error`` tasks are chained into the prompts
        of later tasks. ``file_name`` and ``plugin_type`` are accepted for
        interface compatibility but unused here; tools receive the file set
        via ``set_file``.
        """
        tasks = self._predict(message)
        if isinstance(tasks, dict):
            tasks = [tasks]
        elif not isinstance(tasks, list):
            # e.g. None or a bare string from the parser: not a usable plan.
            tasks = []

        external_data = ""
        result = {"tool": "error", "output": "text", "data": "MCP engine execute error"}
        for task in tasks:
            if not isinstance(task, dict):
                continue
            result = self._run_task(message, task, external_data)
            data = result.get('data')
            if result.get('tool') not in ("chat", "error") and data is not None:
                # Tool data may be non-str (or None for an unset context); stringify and separate chunks.
                external_data += f"{data}\n"
        return result
|
|
|
|
|
|
# Module-level engine with the built-in tools registered (order preserved).
mcp_engine = MCPEngine()
for _builtin_tool in (weather_search, web_scraper, rag_search):
    mcp_engine.add_tool(_builtin_tool)
del _builtin_tool
|