71 lines
2.1 KiB
Python
71 lines
2.1 KiB
Python
import os

from langchain_community.document_loaders import PyPDFLoader, TextLoader, Docx2txtLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter

from core.model import vision_instance
from core.types import BaseLoader
|
|
|
|
|
|
|
|
|
|
class ImageLoader:
    """Loader that converts an image file into a single text ``Document``.

    The text is whatever ``vision_instance.describe()`` returns after the raw
    image bytes have been handed to ``vision_instance.load_image()``.
    """

    def __init__(self, file_path: str):
        """Store the path of the image that ``load`` will read.

        Args:
            file_path: Path of the image file, used as given when opening it.
        """
        self.file_path = file_path

    def load(self):
        """Describe the image and wrap the description in a document list.

        Returns:
            A one-element list holding a ``Document`` whose ``page_content``
            is the description and whose metadata records the file path under
            the ``"source"`` key.
        """
        # Imported at call time, so langchain is only needed once an image is
        # actually loaded.
        from langchain.schema import Document

        with open(self.file_path, "rb") as image_file:
            raw_bytes = image_file.read()

        # The shared vision instance is fed the bytes first, then queried.
        vision_instance.load_image(raw_bytes)
        caption = vision_instance.describe()

        document = Document(
            page_content=caption,
            metadata={"source": self.file_path},
        )
        return [document]
|
|
|
|
|
|
class DocumentLoader(BaseLoader):
    """Load a file from ``./docs/files/`` and split it into text chunks.

    Typical use is ``load_content(...)`` followed by ``load_and_split()``.
    ``split_text`` can be used on its own for raw strings.
    """

    # Maps a lower-case file extension (leading dot included) to the loader
    # class that is instantiated with the file path.
    SUPPORTED_LOADERS = {
        ".pdf": PyPDFLoader,
        ".txt": TextLoader,
        ".docx": Docx2txtLoader,
        # NOTE(review): Docx2txtLoader targets the .docx format; legacy binary
        # .doc files will probably fail to parse here — confirm and consider a
        # dedicated loader.
        ".doc": Docx2txtLoader,
        ".jpeg": ImageLoader,
        ".jpg": ImageLoader,
        ".png": ImageLoader,
    }

    def __init__(self, config=None):
        """Create the loader and its text splitter.

        Args:
            config: Optional mapping. ``"chunk_size"`` (default 512) and
                ``"chunk_overlap"`` (default 200) are forwarded to
                ``RecursiveCharacterTextSplitter``. ``None`` means defaults.
        """
        super().__init__()
        config = config or {}
        self.file_path = None
        self.extension = None
        self._document = None
        self.chunk_size = config.get("chunk_size", 512)
        self.chunk_overlap = config.get("chunk_overlap", 200)
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
        )

    @property
    def loader(self):
        """Build a new loader instance for the current file.

        Returns:
            An instance of the class registered in ``SUPPORTED_LOADERS`` for
            ``self.extension``, constructed with ``self.file_path``.

        Raises:
            ValueError: If ``self.extension`` is not a supported extension.
        """
        if self.extension not in self.SUPPORTED_LOADERS:
            raise ValueError(f"不支持的文件类型: {self.extension}")
        return self.SUPPORTED_LOADERS[self.extension](self.file_path)

    def load_and_split(self):
        """Split the most recently loaded document(s) into chunks.

        Returns:
            The result of ``text_splitter.split_documents`` applied to the
            documents stored by ``load_content``.

        Raises:
            ValueError: If no document has been loaded successfully yet.
        """
        # Without this guard the splitter would fail with an opaque
        # "NoneType is not iterable" error.
        if self._document is None:
            raise ValueError("尚未加载文档，请先调用 load_content")
        return self.text_splitter.split_documents(self._document)

    def load_content(self, file_path, extension=None):
        """Read a file from ``./docs/files/`` and load it with a matching loader.

        Args:
            file_path: Path of the file relative to ``./docs/files/``.
            extension: File extension including the leading dot (for example
                ``".pdf"``). When omitted or empty it is derived from
                ``file_path``. Matching is case-insensitive.

        Raises:
            ValueError: If the extension is not in ``SUPPORTED_LOADERS``.
            OSError: If the file cannot be opened or read.
        """
        # Drop any previously loaded document so that a failure below cannot
        # leave stale content for load_and_split to return.
        self._document = None

        if not extension:
            # No explicit extension: fall back to the one in the file name.
            extension = os.path.splitext(file_path)[1]
        # Keys of SUPPORTED_LOADERS are lower-case, so ".PDF" must match ".pdf".
        self.extension = extension.lower()
        self.file_path = "./docs/files/" + file_path

        # Resolve the loader first so an unsupported type is rejected before
        # the whole file is read into memory.
        document_loader = self.loader

        with open(self.file_path, "rb") as source_file:
            self.content = source_file.read()

        self._document = document_loader.load()

    def split_text(self, content: str) -> list[str]:
        """Split a raw string into chunks with the configured text splitter.

        Args:
            content: Text to split.

        Returns:
            The chunks produced by ``text_splitter.split_text``.
        """
        return self.text_splitter.split_text(content)
|