Files
chat-bot/core/model.py
lychang 64ce30fdfd init
2025-08-26 09:35:29 +08:00

135 lines
3.9 KiB
Python

import base64
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain.schema import BaseMessage
from core.config import model_conf, conf
from core.role import User, AI, System
from core.types import BaseModel
# Shared role helpers. Their .generate() output is what ChatModel appends to
# its dialogue and what VisionModel sends, for the user, the assistant (AI)
# and the system speaker respectively.
user, ai, system = User(), AI(), System()
class ChatModel(BaseModel):
    """Conversational wrapper around a LangChain ``ChatOpenAI`` client.

    Keeps a running ``dialogue`` (a list of LangChain messages) that
    :meth:`chat` sends as one conversation. :meth:`llm` sends a single
    one-shot prompt and ignores the dialogue.
    """

    def __init__(self, config: dict):
        """Build the underlying ``ChatOpenAI`` client from *config*.

        Args:
            config: Mapping with required keys ``"host"`` (base URL),
                ``"key"`` (API key), ``"name"`` (model name) and
                ``"temperature"``. The optional ``"thinking"`` (default
                ``False``) is forwarded as ``enable_thinking`` in the
                request body.

        Raises:
            KeyError: if *config* is a dict missing one of the required keys.
        """
        super().__init__(config)
        self._model = ChatOpenAI(
            base_url=config["host"],
            api_key=config["key"],
            model=config["name"],
            temperature=config["temperature"],
            # Non-standard request field passed through verbatim; presumably
            # understood by the configured OpenAI-compatible backend - confirm.
            extra_body={"enable_thinking": config.get("thinking", False)},
        )
        self.dialogue: list[BaseMessage] = []

    def add_message(self, message: BaseMessage):
        """Append *message* to the end of the dialogue."""
        self.dialogue.append(message)

    def load_chat_history(self, chat_history: list[dict]):
        """Replace the dialogue with messages rebuilt from *chat_history*.

        Each entry must have a ``"role"`` key. Entries with a recognised role
        must also have a ``"message"`` key, whose value is passed to the
        matching role helper:

        * ``"user"``     -> ``user.generate``
        * ``"system"``   -> ``ai.generate``
        * ``"external"`` -> ``system.generate``

        Entries with any other role are skipped.

        NOTE(review): stored "system" turns become AI messages and "external"
        turns become system messages. This looks intentional (the bot's own
        replies seem to be stored as "system"), but confirm against whatever
        writes the history.
        """
        # Dispatch table instead of an if/elif chain; the helpers are looked
        # up at call time, exactly as the chain did.
        role_builders = {"user": user, "system": ai, "external": system}
        self.dialogue = []
        for entry in chat_history:
            builder = role_builders.get(entry["role"])
            if builder is not None:
                self.add_message(builder.generate(entry["message"]))

    @staticmethod
    def _parser_llm_result(llm_result) -> str:
        """Concatenate the message text of every generation for the first prompt."""
        # str.join instead of repeated += (which is quadratic in the general case).
        return "".join(gen.message.content for gen in llm_result.generations[0])

    def llm(self, message: str) -> str:
        """Send *message* as a single prompt and return the reply text.

        ``self.dialogue`` is not used.
        """
        return self._model.invoke(message).content

    def chat(self, message: str) -> str:
        """Send the whole ``dialogue`` as one conversation and return the reply text.

        NOTE(review): *message* is not read. The request is built only from
        ``self.dialogue``, so the current user turn is sent only if it was
        already appended (e.g. via :meth:`add_message`). The parameter is kept
        for interface compatibility.
        """
        llm_result = self._model.generate([self.dialogue])
        return self._parser_llm_result(llm_result)
class EmbeddingModel(BaseModel):
    """Text-embedding wrapper around a LangChain ``OpenAIEmbeddings`` client."""

    def __init__(self, config: dict):
        """Create the embeddings client from *config*.

        *config* must provide ``"host"`` (base URL), ``"key"`` (API key)
        and ``"name"`` (model name).
        """
        super().__init__(config)
        client_kwargs = {
            "base_url": config["host"],
            "api_key": config["key"],
            "model": config["name"],
            # Turns off the client-side context-length check; presumably so
            # raw text is sent to a non-OpenAI backend - confirm.
            "check_embedding_ctx_length": False,
        }
        self._model = OpenAIEmbeddings(**client_kwargs)

    def embed_query(self, text: str):
        """Return the embedding of a single *text*."""
        vector = self._model.embed_query(text)
        return vector

    def embed_documents(self, texts: list):
        """Return the embeddings of *texts*, as produced by the client."""
        vectors = self._model.embed_documents(texts)
        return vectors
class VisionModel(BaseModel):
    """Image-understanding wrapper around a LangChain ``ChatOpenAI`` client.

    Holds at most one image (raw bytes set via :meth:`load_image`) and sends
    it with a text prompt as a base64 ``data:`` URL.
    """

    def __init__(self, config: dict):
        """Build the ``ChatOpenAI`` client.

        *config* must provide ``"host"``, ``"key"``, ``"name"`` and
        ``"temperature"``.
        """
        super().__init__(config)
        self._model = ChatOpenAI(
            base_url=config["host"],
            api_key=config["key"],
            model=config["name"],
            temperature=config["temperature"],
        )
        self._data = None  # raw image bytes, or None until load_image() is called

    @property
    def base64(self):
        """Base64 text of the loaded image, or ``""`` if none is loaded."""
        if self._data is None:
            return ""
        # Inside a method body, ``base64`` resolves to the module, not this property.
        return base64.b64encode(self._data).decode('utf-8')

    def load_image(self, bytes_data: bytes):
        """Store *bytes_data* as the image for subsequent queries."""
        self._data = bytes_data

    @staticmethod
    def _mime_type(data) -> str:
        """Guess the image MIME type of *data* from its leading magic bytes.

        Recognises PNG, GIF and WebP. Returns ``"image/jpeg"`` for anything
        else, including no data; this matches the type previously hard-coded.
        """
        if not data:
            return "image/jpeg"
        if data.startswith(b"\x89PNG\r\n\x1a\n"):
            return "image/png"
        if data.startswith((b"GIF87a", b"GIF89a")):
            return "image/gif"
        if data[:4] == b"RIFF" and data[8:12] == b"WEBP":
            return "image/webp"
        return "image/jpeg"

    def _call_vision_function(self, message: str):
        """Send *message* plus the loaded image as one user turn; return the reply text.

        NOTE(review): if no image has been loaded, the data URL carries an
        empty payload, which the server presumably rejects - confirm.
        """
        data_url = f"data:{self._mime_type(self._data)};base64,{self.base64}"
        messages = [
            user.generate([
                {"type": "text", "text": message},
                # The OpenAI chat schema defines image_url as an object with a
                # "url" key, not a bare string.
                {"type": "image_url", "image_url": {"url": data_url}},
            ])
        ]
        return self._model.invoke(messages).content

    def _call_vision_messages(self, messages: list):
        """Invoke the model on caller-built *messages* and return the reply text."""
        return self._model.invoke(messages).content

    def query(self, message: str) -> str:
        """Ask *message* about the loaded image."""
        return self._call_vision_function(message)

    def describe(self):
        """Ask the model for a complete, detailed description of the loaded image."""
        # Prompt (Chinese): "Describe the information in the image completely and in detail."
        return self._call_vision_function("完整详细的描述图片中的信息")
def get_embedding_model(model_type: str = ""):
    """Build an EmbeddingModel from the ``model_conf`` entry for *model_type*."""
    config = model_conf.get(model_type)
    return EmbeddingModel(config)
def get_vision_model(model_type: str = ""):
    """Build a VisionModel from the ``model_conf`` entry for *model_type*."""
    config = model_conf.get(model_type)
    return VisionModel(config)
def get_chat_model(model_type: str = ""):
    """Build a ChatModel from the ``model_conf`` entry for *model_type*."""
    config = model_conf.get(model_type)
    return ChatModel(config)
# Shared model clients, created at import time from the model names in
# ``conf``, in this order: think, llm, vision.
think_instance, llm_instance, vision_instance = (
    get_chat_model(conf.think_model),
    get_chat_model(conf.llm_model),
    get_vision_model(conf.vision_model),
)
def run_llm_by_message(message: str):
    """Answer *message* as a one-shot prompt with the shared ``llm_instance``.

    The instance's dialogue is cleared first, even though ``llm()`` itself
    does not read it.
    """
    model = llm_instance
    model.load_chat_history([])
    reply = model.llm(message)
    return reply
def think_by_message(message: str):
    """Answer *message* as a one-shot prompt with the shared ``think_instance``.

    The instance's dialogue is cleared first, even though ``llm()`` itself
    does not read it.
    """
    model = think_instance
    model.load_chat_history([])
    reply = model.llm(message)
    return reply