Files
chat-bot/extension/standard.py
lychang 64ce30fdfd init
2025-08-26 09:35:29 +08:00

209 lines
6.7 KiB
Python

import base64
import json
import os
import re
from pysmx.SM3 import hexdigest
from core.config import conf
from core.types import BaseCRUD
from core.model import run_llm_by_message
from core.types import BasePrompt
class LocalResourceManager:
    """File-backed resource store rooted under the current working directory.

    A resource is addressed as ``<BASE_DIR>/<resource_type>/<resource_name>``.
    Empty ``resource_type`` / ``resource_name`` components are dropped from
    the path.
    """

    def __init__(self, base_dir: str = "/docs"):
        # base_dir is appended verbatim to the working directory, so it is
        # expected to start with a path separator (the default is "/docs").
        self.BASE_DIR = os.getcwd() + base_dir

    def _resolve(self, resource_type: str, resource_name: str) -> str:
        """Build the on-disk path for a resource.

        Every public method goes through this single helper so that they all
        agree on the separator. Plain concatenation with ``os.sep`` is used
        (rather than ``os.path.join``) so that a component starting with a
        separator cannot replace the base directory.
        """
        parts = [p for p in (self.BASE_DIR, resource_type, resource_name) if p != ""]
        return os.sep.join(parts)

    @staticmethod
    def generate_sm3(file_bytes: bytes):
        """Return the SM3 hex digest of ``file_bytes``."""
        return hexdigest(file_bytes)

    def exists(self, resource_type: str, resource_name: str) -> bool:
        """Return True when the addressed resource is present on disk."""
        return os.path.exists(self._resolve(resource_type, resource_name))

    @staticmethod
    def _write(path: str, file: bytes) -> None:
        """Write ``file`` to ``path`` in binary mode, replacing any content."""
        with open(path, "wb") as f:
            f.write(file)

    @staticmethod
    def _read(path: str) -> bytes:
        """Return the full binary content of ``path``."""
        with open(path, "rb") as f:
            return f.read()

    @staticmethod
    def _remove(path) -> bool:
        """Delete ``path``; return False instead of raising on an OS error."""
        try:
            os.remove(path)
        except OSError:
            return False
        return True

    def create(self, resource_type: str, resource_name: str, file: bytes) -> bool:
        """Create a new resource.

        Returns:
            True when the file was written, False when it already exists.

        Raises:
            Exception: if writing the file fails (the cause is chained).
        """
        if self.exists(resource_type, resource_name):
            return False
        file_path = self._resolve(resource_type, resource_name)
        try:
            self._write(file_path, file)
        except Exception as e:
            raise Exception(f"Create file failed: {file_path}") from e
        return True

    def update(self, resource_type: str, resource_name: str, file: bytes) -> bool:
        """Overwrite an existing resource.

        Returns:
            True when the file was rewritten, False when it does not exist.

        Raises:
            Exception: if writing the file fails (the cause is chained).
        """
        if not self.exists(resource_type, resource_name):
            return False
        file_path = self._resolve(resource_type, resource_name)
        try:
            self._write(file_path, file)
        except Exception as e:
            raise Exception(f"Update file failed: {file_path}") from e
        return True

    def get(self, resource_type: str, resource_name: str) -> bytes:
        """Return the bytes of a resource.

        Raises:
            FileNotFoundError: if the resource does not exist.
        """
        file_path = self._resolve(resource_type, resource_name)
        if not self.exists(resource_type, resource_name):
            raise FileNotFoundError(f"File is not exists: {file_path}")
        return self._read(file_path)

    def delete(self, resource_type: str, resource_name: str) -> bool:
        """Delete a resource.

        Returns:
            True when the file was removed, False when removal failed.

        Raises:
            FileNotFoundError: if the resource does not exist.
        """
        file_path = self._resolve(resource_type, resource_name)
        # Only an existing file can be removed; a missing one is an error.
        if not self.exists(resource_type, resource_name):
            raise FileNotFoundError(f"File is not exists: {file_path}")
        return self._remove(file_path)

    def download_bytes(self, path: str, file_type: str = ""):
        """Return the bytes of ``path`` under ``file_type`` (argument-swapped alias of ``get``)."""
        return self.get(file_type, path)
class ChatFileManager(LocalResourceManager):
    """Content-addressed store for chat attachments.

    Files are stored under the ``"files"`` resource type, named by the SM3
    digest of their content. The public file id has the form
    ``<digest>_<base64(extension)>`` so the original extension can be
    recovered without being part of the on-disk name.
    """

    def __init__(self, base_dir: str = "/docs"):
        super().__init__(base_dir)
        self.type = "files"
        # Parsed content of json/file_extension.json; the read raises
        # FileNotFoundError if that resource is missing.
        self.extension_mapping = json.loads(self.get("json", "file_extension.json"))

    @staticmethod
    def encode_postfix(content: str) -> str:
        """Return ``content`` encoded as UTF-8 and then base64, as text."""
        return base64.b64encode(content.encode("utf-8")).decode("utf-8")

    @staticmethod
    def decode_postfix(content: str) -> str:
        """Inverse of ``encode_postfix``."""
        return base64.b64decode(content).decode("utf-8")

    def parse_file_id(self, file_id):
        """Split a public file id into ``(digest, extension)``.

        The extension is returned with a leading dot, or as ``""`` when the
        id carries no extension (no ``_`` at all, or an empty encoded part).
        Only the first ``_`` is treated as the separator, so an id with extra
        underscores no longer fails tuple unpacking.
        """
        digest, separator, encoded = file_id.partition("_")
        if not separator:
            return file_id, ""
        extension = self.decode_postfix(encoded)
        # An extension-less upload yields an empty encoded part; do not turn
        # that into a bare ".".
        return digest, ("." + extension if extension else "")

    def _generate_file_id(self, file_name: str, file_content: bytes):
        """Return ``(digest, encoded_extension)`` for an upload.

        The extension is whatever follows the last ``.`` in ``file_name``
        (empty when there is no dot).
        """
        postfix = file_name.rpartition(".")[2] if "." in file_name else ""
        # Use the inherited static method rather than a module-level instance
        # so the class does not depend on import-time globals.
        file_id = self.generate_sm3(file_content)
        return file_id, self.encode_postfix(postfix)

    def c_get(self, file_id: str) -> "tuple[bytes, str]":
        """Return ``(content, extension)`` for a public file id.

        Raises:
            FileNotFoundError: if no file is stored under the digest.
        """
        digest, postfix = self.parse_file_id(file_id)
        return self.get(self.type, digest), postfix

    def c_create(self, file_name: str, file_content: bytes) -> str:
        """Store ``file_content`` and return its public file id.

        Content that is already stored is not rewritten; its id is returned
        as-is. Returns ``""`` when the file could not be created.
        """
        digest, postfix = self._generate_file_id(file_name, file_content)
        if self.exists(self.type, digest) or self.create(self.type, digest, file_content):
            return digest + "_" + postfix
        return ""
class DatabaseManager:
    """Owns the CRUD accessors for the ``prompts`` and ``schemas`` tables."""

    def __init__(self, db_path):
        self.db_path = db_path
        self.prompts = BaseCRUD('prompts', db_path)
        self.schemas = BaseCRUD('schemas', db_path)

    def init_db(self):
        """Create the ``prompts`` and ``schemas`` tables if they are missing.

        The statements use ``IF NOT EXISTS`` so the call is idempotent. The
        schema is no longer gated on whether the database file exists, because
        an existing but empty or partially initialised file would otherwise
        never receive its tables.
        """
        with self.prompts.get_connection() as conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS prompts (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    name TEXT UNIQUE NOT NULL,
                    content TEXT NOT NULL,
                    variables TEXT,  -- 存储变量信息
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            ''')
            conn.execute('''
                CREATE TABLE IF NOT EXISTS schemas (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    name TEXT UNIQUE NOT NULL,
                    content TEXT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            ''')
            conn.commit()
class Prompt(BasePrompt):
    """Prompt whose template and schema are read through ``resource_manager``."""

    def _load_prompt(self):
        """Return the resource of type ``"prompt"`` named ``self.name``."""
        prompt_name = self.name
        return resource_manager.get("prompt", prompt_name)

    def _load_schema(self):
        """Return the resource of type ``"schema"`` named ``self.name``."""
        schema_name = self.name
        return resource_manager.get("schema", schema_name)
class OnlinePrompt(BasePrompt):
    """Prompt whose template and schema are looked up by name via ``db_manager``."""

    def _load_prompt(self):
        """Return the ``content`` of the prompts row named ``self.name``, or None if there is no row."""
        row = db_manager.prompts.get_by_name(self.name)
        return None if row is None else dict(row).get("content")

    def _load_schema(self):
        """Return the ``content`` of the schemas row named ``self.name``, or None if there is no row."""
        row = db_manager.schemas.get_by_name(self.name)
        return None if row is None else dict(row).get("content")
# Module-level singletons, created as a side effect of importing this module.
# The database manager is built from the configured URI and init_db() is run
# immediately.
db_manager = DatabaseManager(conf.db_uri)
db_manager.init_db()
# Shared file store using the default "/docs" base directory.
resource_manager = LocalResourceManager()
# Constructing ChatFileManager reads the "json/file_extension.json" resource,
# so importing this module fails if that file is missing.
chat_file_manager = ChatFileManager()
def run_llm_by_template(message: str, template_name: str):
    """Build a prompt from the ``OnlinePrompt`` named ``template_name`` and run the LLM on it.

    Args:
        message: Text passed to the prompt's ``generate`` method.
        template_name: Name used to construct the ``OnlinePrompt``.

    Returns:
        Whatever ``run_llm_by_message`` returns for the generated prompt.
    """
    rendered = OnlinePrompt(template_name).generate(message)
    return run_llm_by_message(rendered)
def parse_json_string(text: str):
    """Best-effort extraction of a JSON value from LLM output.

    Steps, in order:
      1. Drop everything up to and including the last ``</think>`` marker.
      2. If the text contains a ``` fence, keep the body of the first
         well-formed fenced block (an optional language tag on the opening
         line is ignored). When a fence marker is present but no well-formed
         block is found, the markers are simply removed.
      3. If the text uses only single quotes, swap them for double quotes.
      4. Parse with ``json.loads``.

    Returns:
        The parsed JSON value, or ``{}`` when the text cannot be parsed.
    """
    # rpartition leaves the text unchanged when the marker is absent.
    text = text.rpartition("</think>")[2]

    if "```" in text:
        fenced = re.search(r"```[^\n]*\n(.*?)\n```", text, re.S)
        if fenced:
            text = fenced.group(1)
        else:
            # Single-line or unterminated fence: fall back to stripping the
            # markers instead of failing on a missing match.
            text = text.replace("```", "")

    # Python-dict style output: only safe to rewrite when no double quote is
    # present, otherwise apostrophes inside valid strings would be corrupted.
    if "'" in text and '"' not in text:
        text = text.replace("'", '"')

    try:
        return json.loads(text)
    except (ValueError, RecursionError):
        # json.JSONDecodeError is a ValueError; pathological nesting can
        # raise RecursionError. Either way the input is unusable.
        return {}