38 lines
1013 B
Python
38 lines
1013 B
Python
#!/usr/bin/env python
|
|
# -*- coding:utf-8 -*-
|
|
# @Filename: context.py
|
|
# @Author: lychang
|
|
# @Time: 7/5/2023 5:53 PM
|
|
from sympy.parsing.maxima import sub_dict
|
|
|
|
from extension.rag import rag_pipline
|
|
|
|
from core.types import BaseTool
|
|
|
|
|
|
|
|
|
|
|
|
class RAGSearch(BaseTool):
    """Tool that answers a sub-task by querying ``rag_pipline`` over the attached file.

    Registered under the name ``"rag_search"``; ``execute`` is bound to
    :meth:`search`, so invoking the tool runs a search.
    """

    def __init__(self):
        name = "rag_search"
        # Description (zh): "Use this tool when searching within a file is mentioned."
        description = "提及在文件中搜索时,使用此工具。"
        super().__init__(name, description)
        # The tool's entry point is the search method.
        self.execute = self.search

    @staticmethod
    def _get_sub_task(message: str) -> str:
        """Return the text after the last ``"sub_task:"`` marker in *message*.

        If the marker does not occur, *message* is returned unchanged
        (``rpartition`` yields ``('', '', message)`` in that case).
        """
        return message.rpartition("sub_task:")[2]

    def search(self, message: str):
        """Retrieve passages from the attached file that are relevant to *message*.

        Args:
            message: Raw tool input. If it contains ``"sub_task:"``, only the
                text after the last such marker is used as the query.

        Returns:
            The value of ``self.normal(...)`` (defined on ``BaseTool``, not
            shown here) applied to either the unique retrieved passages,
            separated by a blank line and a ``##`` line, or the fallback
            ``"不包含相关内容"`` ("no relevant content") when no file is set
            or nothing was retrieved.
        """
        sub_task = self._get_sub_task(message)
        result = "不包含相关内容"
        # NOTE(review): ``_file`` is presumably assigned by BaseTool or the
        # caller (not visible here) — confirm where it is set.
        if self._file:
            # NOTE(review): 10 is presumably the number of chunks to retrieve
            # (top-k) — confirm against rag_pipline.query's signature.
            response = rag_pipline.query(sub_task, 10, self._file)
            # De-duplicate while preserving the order the pipeline returned.
            # A set would scramble it (differently on each run, due to string
            # hash randomization) and lose any relevance ranking.
            passages = dict.fromkeys(r["text"] for r in response)
            # Keep the fallback message instead of returning an empty string
            # when the query yields no passages.
            if passages:
                result = "\n\n##\n".join(passages)
        return self.normal(result)
|
|
|
|
|
|
# Module-level tool instance, constructed once when this module is imported.
rag_search: RAGSearch = RAGSearch()
|