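"""Model wrappers: chat, embedding, and vision clients over OpenAI-compatible APIs."""
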
import base64

from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_core.messages import BaseMessage

from core.config import model_conf, conf
from core.role import User, AI, System
from core.types import BaseModel

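# Shared role helpers that wrap raw strings into LangChain message objects.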
user = User()
ai = AI()
system = System()


class ChatModel(BaseModel):
    """Chat model with an in-memory dialogue buffer."""

    def __init__(self, config: dict):
        super().__init__(config)
        self._model = ChatOpenAI(
            base_url=config["host"],
            api_key=config["key"],
            model=config["name"],
            temperature=config["temperature"],
            # Provider-specific flag forwarded in the request body
            # (e.g. Qwen-style "enable_thinking").
            extra_body={"enable_thinking": config.get("thinking", False)})
        self.dialogue = []

    def add_message(self, message: BaseMessage):
        self.dialogue.append(message)

    def load_chat_history(self, chat_history: list[dict]):
        """Rebuild the dialogue from serialized history entries.

        Role mapping: "user" entries become user messages, "system" entries
        become AI messages, and "external" entries become system messages.
        """
        self.dialogue = []
        for entry in chat_history:
            if entry["role"] == "user":
                self.add_message(user.generate(entry["message"]))
            elif entry["role"] == "system":
                self.add_message(ai.generate(entry["message"]))
            elif entry["role"] == "external":
                self.add_message(system.generate(entry["message"]))

    @staticmethod
    def _parse_llm_result(llm_result):
        # Concatenate the text of every generation produced for the first prompt.
        content = ""
        for generation in llm_result.generations[0]:
            content += generation.message.content
        return content

    def llm(self, message: str) -> str:
        # Single-turn call; the dialogue buffer is not used.
        return self._model.invoke(message).content

    def chat(self, message: str) -> str:
        # Record the new user turn before generating so the reply covers it.
        self.add_message(user.generate(message))
        llm_result = self._model.generate([self.dialogue])
        return self._parse_llm_result(llm_result)

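# Usage sketch for ChatModel (the "chat" key below is hypothetical; real keys
# come from model_conf):
#
#     chat_model = get_chat_model("chat")
#     chat_model.load_chat_history([
#         {"role": "user", "message": "Hello"},
#         {"role": "system", "message": "Hi, how can I help?"},
#     ])
#     reply = chat_model.chat("What did I just say?")
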

class EmbeddingModel(BaseModel):
    """Embedding model over an OpenAI-compatible endpoint."""

    def __init__(self, config: dict):
        super().__init__(config)
        self._model = OpenAIEmbeddings(
            base_url=config["host"],
            api_key=config["key"],
            model=config["name"],
            # Skip tiktoken-based context-length checks, which break for
            # non-OpenAI tokenizers behind OpenAI-compatible endpoints.
            check_embedding_ctx_length=False)

    def embed_query(self, text: str):
        return self._model.embed_query(text)

    def embed_documents(self, texts: list):
        return self._model.embed_documents(texts)

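# Usage sketch for EmbeddingModel (the "embedding" key is hypothetical):
#
#     embedder = get_embedding_model("embedding")
#     query_vec = embedder.embed_query("what is attention?")
#     doc_vecs = embedder.embed_documents(["chunk one", "chunk two"])
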

class VisionModel(BaseModel):
    """Vision-capable chat model; images are sent as base64 data URLs."""

    def __init__(self, config: dict):
        super().__init__(config)
        self._model = ChatOpenAI(
            base_url=config["host"],
            api_key=config["key"],
            model=config["name"],
            temperature=config["temperature"])
        self._data = None

    @property
    def base64(self):
        # Encode the loaded image bytes for embedding in a data URL.
        if self._data is None:
            return ""
        return base64.b64encode(self._data).decode('utf-8')

    def load_image(self, bytes_data: bytes):
        self._data = bytes_data

    def _call_vision_function(self, message: str):
        messages = [
            user.generate([
                {"type": "text", "text": message},
                # OpenAI-style vision payloads expect image_url to be an
                # object with a "url" field.
                {"type": "image_url",
                 "image_url": {"url": f"data:image/jpeg;base64,{self.base64}"}},
            ])
        ]
        return self._model.invoke(messages).content

    def _call_vision_messages(self, messages: list):
        return self._model.invoke(messages).content

    def query(self, message: str) -> str:
        return self._call_vision_function(message)

    def describe(self):
        return self._call_vision_function(
            "Describe the information in the image completely and in detail.")

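# Factory helpers: each looks up the named section of model_conf and wraps it
# in the matching model class.
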
def get_embedding_model(model_type: str = ""):
    return EmbeddingModel(model_conf.get(model_type))


def get_vision_model(model_type: str = ""):
    return VisionModel(model_conf.get(model_type))


def get_chat_model(model_type: str = ""):
    return ChatModel(model_conf.get(model_type))


# Module-level singletons built from the configured model names.
think_instance = get_chat_model(conf.think_model)
llm_instance = get_chat_model(conf.llm_model)
vision_instance = get_vision_model(conf.vision_model)


def run_llm_by_message(message: str):
    # Clear any stored dialogue, then issue a stateless single-turn call.
    llm_instance.load_chat_history([])
    return llm_instance.llm(message)


def think_by_message(message: str):
    think_instance.load_chat_history([])
    return think_instance.llm(message)
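

# End-to-end usage sketch. The module path `core.models` is an assumption;
# import from wherever this file actually lives:
#
#     from core.models import run_llm_by_message, vision_instance
#
#     print(run_llm_by_message("Summarize the design in one sentence."))
#
#     with open("photo.jpg", "rb") as f:
#         vision_instance.load_image(f.read())
#     print(vision_instance.describe())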