209 lines
6.7 KiB
Python
209 lines
6.7 KiB
Python
import base64
|
|
import json
|
|
import os
|
|
import re
|
|
|
|
from pysmx.SM3 import hexdigest
|
|
|
|
from core.config import conf
|
|
from core.types import BaseCRUD
|
|
from core.model import run_llm_by_message
|
|
from core.types import BasePrompt
|
|
|
|
|
|
class LocalResourceManager:
    """Store and retrieve resource files on the local filesystem.

    A resource is addressed by ``resource_type`` (a sub-directory of
    ``BASE_DIR``; may be "" for the root) and ``resource_name`` (the file
    name inside it).
    """

    def __init__(self, base_dir: str = "/docs"):
        """Root all resources at ``os.getcwd() + base_dir``.

        :param base_dir: suffix appended verbatim to the current working
            directory, so it is expected to start with a separator.
        """
        # Plain concatenation on purpose: base_dir starts with "/", and
        # os.path.join would treat it as absolute and drop the cwd.
        self.BASE_DIR = os.getcwd() + base_dir

    def _build_path(self, resource_type: str, resource_name: str) -> str:
        """Return the file path of a resource, skipping empty components.

        Every public method goes through this helper so that exists/get/
        create/update/delete always agree on the same path. "/" is used as
        the separator because Windows accepts it too. Because the parts are
        joined as plain strings (not os.path.join), an absolute
        resource_name still ends up under BASE_DIR.
        NOTE(review): ".." components are not rejected; sanitise names
        before passing user-controlled input here.
        """
        parts = [part for part in (self.BASE_DIR, resource_type, resource_name) if part != ""]
        return "/".join(parts)

    @staticmethod
    def generate_sm3(file_bytes: bytes):
        """Return the SM3 digest of ``file_bytes`` (pysmx ``hexdigest``)."""
        return hexdigest(file_bytes)

    def exists(self, resource_type: str, resource_name: str) -> bool:
        """Return True if the resource file is present on disk."""
        return os.path.exists(self._build_path(resource_type, resource_name))

    @staticmethod
    def _write(path: str, file: bytes):
        """Write ``file`` to ``path`` in binary mode, replacing any content."""
        with open(path, "wb") as f:
            f.write(file)

    @staticmethod
    def _read(path: str) -> bytes:
        """Return the full binary content of ``path``."""
        with open(path, "rb") as f:
            return f.read()

    @staticmethod
    def _remove(path) -> bool:
        """Delete ``path``; return False instead of raising if the OS refuses."""
        try:
            os.remove(path)
        except OSError:
            return False
        return True

    def create(self, resource_type: str, resource_name: str, file: bytes) -> bool:
        """Write a new resource.

        :return: True if written, False if the resource already exists.
        :raises Exception: if writing fails; the OS error is chained as the cause.
        """
        if self.exists(resource_type, resource_name):
            return False
        file_path = self._build_path(resource_type, resource_name)
        try:
            self._write(file_path, file)
        except Exception as e:
            raise Exception(f"Create file failed: {file_path}") from e
        return True

    def update(self, resource_type: str, resource_name: str, file: bytes) -> bool:
        """Overwrite an existing resource.

        :return: True if written, False if the resource does not exist.
        :raises Exception: if writing fails; the OS error is chained as the cause.
        """
        if not self.exists(resource_type, resource_name):
            return False
        file_path = self._build_path(resource_type, resource_name)
        try:
            self._write(file_path, file)
        except Exception as e:
            raise Exception(f"Update file failed: {file_path}") from e
        return True

    def get(self, resource_type: str, resource_name: str) -> bytes:
        """Return the bytes of a resource.

        :raises FileNotFoundError: if the resource does not exist.
        """
        file_path = self._build_path(resource_type, resource_name)
        if not self.exists(resource_type, resource_name):
            raise FileNotFoundError(f"File is not exists: {file_path}")
        return self._read(file_path)

    def delete(self, resource_type: str, resource_name: str) -> bool:
        """Remove a resource.

        :return: True if removed, False if the OS refused the removal.
        :raises FileNotFoundError: if the resource does not exist.
        """
        file_path = self._build_path(resource_type, resource_name)
        # Only an existing file can be removed; a missing one is an error.
        if self.exists(resource_type, resource_name):
            return self._remove(file_path)
        raise FileNotFoundError(f"File is not exists: {file_path}")

    def download_bytes(self, path: str, file_type: str = ""):
        """Return the bytes of ``path`` under the ``file_type`` directory (BASE_DIR root by default)."""
        return self.get(file_type, path)
|
|
|
|
|
|
class ChatFileManager(LocalResourceManager):
    """Content-addressed storage for chat attachments.

    Files are stored under ``<BASE_DIR>/files`` and named by the SM3 digest
    of their content. The public file id is ``"<digest>_<base64(extension)>"``,
    so the original extension can be recovered without extra metadata.
    """

    def __init__(self, base_dir: str = "/docs"):
        super().__init__(base_dir)
        self.type = "files"
        # Parsed from <BASE_DIR>/json/file_extension.json. Not read by any
        # method of this class; presumably consumed by callers.
        self.extension_mapping = json.loads(self.get("json", "file_extension.json"))

    @staticmethod
    def encode_postfix(content: str) -> str:
        """Base64-encode an extension string (UTF-8) for use in a file id."""
        return base64.b64encode(content.encode("utf-8")).decode("utf-8")

    @staticmethod
    def decode_postfix(content: str) -> str:
        """Inverse of ``encode_postfix``."""
        return base64.b64decode(content).decode("utf-8")

    def parse_file_id(self, file_id):
        """Split a public file id into ``(digest, postfix)``.

        ``postfix`` is the decoded extension with a leading "." (e.g. ".pdf"),
        or "" when the id has no "_" part or encodes an empty extension.
        """
        if "_" not in file_id:
            return file_id, ""
        digest, encoded = file_id.split("_")
        extension = self.decode_postfix(encoded)
        # A file without an extension is encoded as "<digest>_"; return ""
        # rather than a bare "." in that case.
        return digest, ("." + extension if extension else "")

    def _generate_file_id(self, file_name: str, file_content: bytes):
        """Return ``(digest, encoded_extension)`` for an uploaded file."""
        extension = file_name.rsplit(".", 1)[-1] if "." in file_name else ""
        return self.generate_sm3(file_content), self.encode_postfix(extension)

    def c_get(self, file_id: str) -> (bytes, str):
        """Return ``(content, postfix)`` for a public file id.

        :raises FileNotFoundError: if no file is stored for the digest.
        """
        file_id, postfix = self.parse_file_id(file_id)
        return self.get(self.type, file_id), postfix

    def c_create(self, file_name: str, file_content: bytes) -> str:
        """Store a file (deduplicated by content) and return its public id.

        :return: ``"<digest>_<base64(extension)>"``, or "" if the write was refused.
        """
        file_id, postfix = self._generate_file_id(file_name, file_content)
        public_id = file_id + "_" + postfix
        # Identical content hashes to the same digest, so an existing file is reused.
        if self.exists(self.type, file_id) or self.create(self.type, file_id, file_content):
            return public_id
        return ""
|
|
|
|
|
|
class DatabaseManager:
    """Hold the CRUD helpers for the ``prompts`` and ``schemas`` tables."""

    def __init__(self, db_path):
        """
        :param db_path: database location handed to both ``BaseCRUD`` helpers.
        """
        self.db_path = db_path
        self.prompts = BaseCRUD('prompts', db_path)
        self.schemas = BaseCRUD('schemas', db_path)

    def init_db(self):
        """Create the ``prompts`` and ``schemas`` tables if they are missing.

        Idempotent: ``CREATE TABLE IF NOT EXISTS`` is used instead of
        checking whether ``db_path`` exists on disk. That check skipped table
        creation when the file already existed without tables, and it could
        never be True when ``db_path`` is not a plain filesystem path.
        """
        with self.prompts.get_connection() as conn:
            # The SQL comment on `variables` means "stores variable information".
            conn.execute('''
                CREATE TABLE IF NOT EXISTS prompts (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    name TEXT UNIQUE NOT NULL,
                    content TEXT NOT NULL,
                    variables TEXT, -- 存储变量信息
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            ''')
            conn.execute('''
                CREATE TABLE IF NOT EXISTS schemas (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    name TEXT UNIQUE NOT NULL,
                    content TEXT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            ''')
            conn.commit()
|
|
|
|
|
|
class Prompt(BasePrompt):
    """Prompt whose template and schema come from local resource files.

    Both are looked up by ``self.name`` (presumably set by ``BasePrompt``)
    through the module-level ``resource_manager``. A missing file raises
    ``FileNotFoundError``.
    NOTE(review): these loaders return raw bytes, while ``OnlinePrompt``
    returns text; confirm ``BasePrompt`` handles both.
    """

    def _load_prompt(self):
        """Return the bytes of the ``prompt/<name>`` resource."""
        folder = "prompt"
        return resource_manager.get(folder, self.name)

    def _load_schema(self):
        """Return the bytes of the ``schema/<name>`` resource."""
        folder = "schema"
        return resource_manager.get(folder, self.name)
|
|
|
|
|
|
class OnlinePrompt(BasePrompt):
    """Prompt whose template and schema are stored in the database.

    Rows are fetched by ``self.name`` from the ``prompts`` / ``schemas``
    tables through the module-level ``db_manager``. Only the ``content``
    column is returned, or None when no row matches.
    """

    @staticmethod
    def _row_content(row):
        """Return the ``content`` field of a lookup result, or None if there is no row."""
        if row is None:
            return None
        return dict(row).get("content")

    def _load_prompt(self):
        """Return the stored prompt template text, or None if missing."""
        return self._row_content(db_manager.prompts.get_by_name(self.name))

    def _load_schema(self):
        """Return the stored schema text, or None if missing."""
        return self._row_content(db_manager.schemas.get_by_name(self.name))
|
|
|
|
|
|
# Module-level singletons, created at import time. Statement order matters:
# each line below performs I/O as soon as this module is imported.

# Prompt/schema database; init_db() creates the tables it needs.
db_manager = DatabaseManager(conf.db_uri)
db_manager.init_db()

# Generic file store rooted at os.getcwd() + "/docs"; used by Prompt to load
# local templates and schemas.
resource_manager = LocalResourceManager()
# Chat attachment store. Its constructor reads json/file_extension.json under
# the same root, so importing this module fails if that file is missing.
chat_file_manager = ChatFileManager()
|
|
|
|
|
|
def run_llm_by_template(message: str, template_name: str):
    """Build a prompt from a database-stored template and send it to the LLM.

    :param message: input passed to the template's ``generate``.
    :param template_name: template name, loaded through ``OnlinePrompt``
        (i.e. looked up in the ``prompts`` table).
    :return: whatever ``run_llm_by_message`` returns for the built prompt.
    """
    rendered = OnlinePrompt(template_name).generate(message)
    return run_llm_by_message(rendered)
|
|
|
|
|
|
def parse_json_string(text: str):
    """Extract and parse a JSON value from an LLM response.

    1. Everything up to and including the last ``</think>`` tag is dropped,
       so any reasoning preamble is ignored.
    2. If the text contains a Markdown code fence (three backticks), only
       the body of the first fenced block is parsed. The body is the part
       after the opening fence line and before the newline preceding the
       closing fence.
    3. Otherwise, if the text uses single quotes and no double quotes, the
       single quotes are swapped for double quotes so Python-style dict
       literals parse as JSON.

    :param text: raw model output.
    :return: the parsed JSON value, or ``{}`` when a fence is present but no
        fenced block matches, or when the text is not valid JSON.
    """
    if "</think>" in text:
        text = text.split("</think>")[-1]

    if "```" in text:
        match = re.search("```(.*?){0,1}\n(.*?)\n```", text, re.S)
        if match is None:
            # e.g. a single-line "```{...}```": previously an IndexError;
            # treat it as unparsable, like any other malformed input.
            return {}
        text = match.group(2)
    elif "'" in text and '"' not in text:
        text = text.replace("'", "\"")

    try:
        return json.loads(text)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return {}
|