380 lines
12 KiB
Python
380 lines
12 KiB
Python
import json
|
||
import os
|
||
import time
|
||
import logging
|
||
from typing import Dict, Any, Optional
|
||
import requests
|
||
from requests.adapters import HTTPAdapter
|
||
from urllib3.util.retry import Retry
|
||
|
||
from jimeng import VolcEngineAPIClient
|
||
|
||
|
||
class AIAgent:
    """AI agent bundling DeepSeek chat calls with report-data retrieval.

    An instance owns a retry-enabled ``requests`` session, posts chat
    requests to the DeepSeek chat-completions endpoint, and keeps one JSON
    cache file per year under ``CACHE_DIR``.
    """

    # Base URL of the remote report resources (data files and prompt templates).
    BASE_REPORT_URL = "https://a.cmdp.cn/basiceg/v1/json/tpldev"
    # Directory (relative to the working directory) holding the cache files.
    CACHE_DIR = "cache"

    # NOTE(review): the default API key below is a credential committed to
    # source. It is kept only so existing callers that rely on the default
    # keep working; rotate it and pass the key in explicitly instead.
    def __init__(self, deepseek_api_key: str = "sk-c2f04b4d635b4df0aace0432f3c6d6f4",
                 max_retries: int = 3, timeout: int = 30):
        """Initialise the agent.

        Args:
            deepseek_api_key: DeepSeek API key sent as a bearer token.
            max_retries: Maximum number of retries for the HTTP session.
            timeout: Default request timeout in seconds.
        """
        self.deepseek_api_key = deepseek_api_key
        self.timeout = timeout

        # Logging must be ready before anything else can report errors.
        self._setup_logging()

        # Session with the retry policy applied to both URL schemes.
        self.session = self._create_session(max_retries)

    def _setup_logging(self):
        """Configure logging and create ``self.logger``."""
        # basicConfig does nothing when the root logger already has handlers.
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )
        self.logger = logging.getLogger(__name__)

    @staticmethod
    def _create_session(max_retries: int) -> "requests.Session":
        """Create a ``requests`` session that retries failed requests.

        Args:
            max_retries: Total number of retries allowed per request.

        Returns:
            A session with the retry adapter mounted for http:// and https://.
        """
        session = requests.Session()

        retry_methods = ["HEAD", "GET", "POST", "PUT", "DELETE", "OPTIONS", "TRACE"]
        retry_options = {
            "total": max_retries,
            "status_forcelist": [429, 500, 502, 503, 504],
            "backoff_factor": 1,
        }
        # NOTE(review): POST is in the retry list, so a chat request may be
        # sent more than once when the server answers with a listed status.
        try:
            # urllib3 >= 1.26 spells the option ``allowed_methods``; the old
            # ``method_whitelist`` keyword was removed in urllib3 2.0.
            retry_strategy = Retry(allowed_methods=retry_methods, **retry_options)
        except TypeError:
            # Older urllib3 releases only accept the previous keyword name.
            retry_strategy = Retry(method_whitelist=retry_methods, **retry_options)

        adapter = HTTPAdapter(max_retries=retry_strategy)
        session.mount("http://", adapter)
        session.mount("https://", adapter)

        return session

    def _make_request(self, method: str, url: str, **kwargs) -> Optional[Dict[str, Any]]:
        """Send an HTTP request and decode the JSON response.

        Args:
            method: HTTP method name.
            url: Request URL.
            **kwargs: Extra arguments forwarded to ``Session.request``;
                ``timeout`` defaults to ``self.timeout`` when not supplied.

        Returns:
            The decoded JSON payload, or None when the request fails or the
            body is not valid JSON.
        """
        kwargs.setdefault("timeout", self.timeout)

        try:
            response = self.session.request(method, url, **kwargs)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            self.logger.error("请求失败: %s - %s", url, e)
            return None

        self.logger.info("请求成功: %s", url)

        try:
            return response.json()
        except (ValueError, requests.exceptions.RequestException) as e:
            # Every JSON decode error raised by requests (stdlib json or
            # simplejson based) derives from ValueError, so it is reported
            # as a parse failure rather than a request failure.
            self.logger.error("JSON解析失败: %s - %s", url, e)
            return None

    @staticmethod
    def _extract_reply(result: Any) -> str:
        """Return the first choice's message content, or "" if unavailable."""
        try:
            content = result["choices"][0]["message"]["content"]
        except (KeyError, IndexError, TypeError):
            # Missing keys, an empty ``choices`` list or a non-mapping result.
            return ""
        return content if isinstance(content, str) else ""

    def chat_deepseek(self, message: str, role: str = "", model: str = "deepseek-chat") -> str:
        """Send one chat message to DeepSeek and return the reply text.

        Args:
            message: User message.
            role: Optional system prompt; omitted from the request when empty.
            model: Model name to use.

        Returns:
            The reply content, or an empty string when the call fails or the
            response carries no usable content.
        """
        url = "https://api.deepseek.com/chat/completions"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.deepseek_api_key}"
        }

        messages = []
        if role:
            messages.append({"role": "system", "content": role})
        messages.append({"role": "user", "content": message})

        body = {
            "model": model,
            "messages": messages,
            "stream": False
        }

        result = self._make_request("POST", url, json=body, headers=headers)
        reply = self._extract_reply(result)
        if not reply:
            self.logger.warning("DeepSeek API调用失败或返回空结果")
        return reply

    @staticmethod
    def _build_report_index(rank_a_data: Any, rank_summary_data: Any,
                            stats_data: Any) -> Dict[str, Any]:
        """Combine the three downloaded payloads into the cached structure.

        Args:
            rank_a_data: Iterable of entries with ``gsdm``, ``gsjc`` and
                ``num`` fields.
            rank_summary_data: Iterable of entries with ``gsdm``, ``kpnd``
                and ``kpjg`` fields.
            stats_data: Payload stored unchanged under the ``"stats"`` key.

        Returns:
            A dict keyed by ``gsdm``, each value holding ``name``, ``num``
            and ``ranks``, plus the ``"stats"`` entry.

        Raises:
            KeyError: If a ``rank_a_data`` entry lacks one of its fields, or
                a ``rank_summary_data`` entry lacks ``gsdm`` (or, for a known
                code, ``kpnd``/``kpjg``).
        """
        index = {
            item["gsdm"]: {
                "name": item["gsjc"],
                "num": item["num"],
                "ranks": {}
            }
            for item in rank_a_data
        }

        # Only codes present in rank_a_data receive rank entries.
        for item in rank_summary_data:
            entry = index.get(item["gsdm"])
            if entry is not None:
                entry["ranks"][item["kpnd"]] = item["kpjg"]

        index["stats"] = stats_data
        return index

    def generate_report_cache_data(self, year: str) -> Dict[str, Any]:
        """Download the report data for a year, build the index and cache it.

        Args:
            year: Year used in the remote path and the cache file name.

        Returns:
            The combined report data, or an empty dict when any of the three
            downloads fails or returns an empty payload.
        """
        self.logger.info("开始生成%s年报告缓存数据", year)

        base_url = f"{self.BASE_REPORT_URL}/disclosure_assessment_report_data/{year}/"
        urls = {
            "rank_a": base_url + "Rank_A.json",
            "rank_summary": base_url + "Rank_summary.json",
            "stats": base_url + "stats.json"
        }

        # The sources are fetched one after another; the first failure aborts.
        data_sources = {}
        for key, url in urls.items():
            self.logger.info("获取%s数据: %s", key, url)
            payload = self._make_request("GET", url)
            if not payload:
                self.logger.error("获取%s数据失败", key)
                return {}
            data_sources[key] = payload

        data = self._build_report_index(
            data_sources["rank_a"],
            data_sources["rank_summary"],
            data_sources["stats"],
        )

        # A failed write is logged by _save_cache; the data is still returned.
        self._save_cache(year, data)

        self.logger.info("%s年报告缓存数据生成完成", year)
        return data

    def _save_cache(self, year: str, data: Dict[str, Any]) -> bool:
        """Write the data to ``<CACHE_DIR>/<year>.json``.

        Args:
            year: Year used as the cache file name.
            data: Data to serialise as JSON.

        Returns:
            True when the file was written, False otherwise.
        """
        try:
            os.makedirs(self.CACHE_DIR, exist_ok=True)
            cache_file = os.path.join(self.CACHE_DIR, f"{year}.json")

            with open(cache_file, "w", encoding="utf-8") as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
        except (OSError, TypeError, ValueError) as e:
            # OSError: filesystem problems; TypeError/ValueError: data that
            # json cannot serialise.
            self.logger.error("保存缓存失败: %s", e)
            return False

        self.logger.info("缓存已保存: %s", cache_file)
        return True

    def _load_cache(self, year: str) -> Optional[Dict[str, Any]]:
        """Read ``<CACHE_DIR>/<year>.json`` if it exists.

        Args:
            year: Year used as the cache file name.

        Returns:
            The cached data, or None when the file is missing or unreadable.
        """
        cache_file = os.path.join(self.CACHE_DIR, f"{year}.json")

        try:
            with open(cache_file, "r", encoding="utf-8") as f:
                data = json.load(f)
        except FileNotFoundError:
            # A missing cache is the normal first-run case; not an error.
            return None
        except (OSError, ValueError) as e:
            # ValueError covers invalid JSON and undecodable bytes.
            self.logger.error("加载缓存失败: %s", e)
            return None

        self.logger.info("缓存已加载: %s", cache_file)
        return data

    def get_report_data(self, security_code: str, year: str) -> Dict[str, Any]:
        """Return the report data of one security for a year.

        The cache is used when present; otherwise the data is generated.

        Args:
            security_code: Security code to look up.
            year: Year of the report data.

        Returns:
            The entry for the security (empty when unknown) with the yearly
            statistics added under ``"stats"``.
        """
        cached_data = self._load_cache(year)

        if cached_data is None:
            cached_data = self.generate_report_cache_data(year)

        # Copy the entry so adding "stats" does not modify the cached data.
        result = dict(cached_data.get(security_code, {}))
        result["stats"] = cached_data.get("stats", {})
        return result

    def get_image_prompt(self, message: str, security_code: str, year: str) -> str:
        """Ask DeepSeek for an image-generation prompt for a security.

        Args:
            message: Painting theme.
            security_code: Security code whose data is put into the request.
            year: Year of the report data.

        Returns:
            The generated prompt, or an empty string when the chat call fails.

        Raises:
            OSError: If ``painting.prompt`` cannot be read from ``CACHE_DIR``.
        """
        self.logger.info("生成图像提示词: %s", message)

        # The system prompt lives next to the cache files.
        prompt_file = os.path.join(self.CACHE_DIR, "painting.prompt")
        with open(prompt_file, "r", encoding="utf-8") as f:
            role = f.read()

        security_data = self.get_report_data(security_code, year)
        security_name = security_data.get("name")
        rank = security_data.get("ranks", {}).get(year, "N/A")
        num = security_data.get("num", 0)
        stats = json.dumps(security_data.get("stats", {}), ensure_ascii=False)
        prompt_message = f"## 相关信息:证券名称是{security_name},{year}年排名是{rank},已经连续{num}年考评A级。\n{stats}\n\n## 绘画主题:{message}。"
        prompt = self.chat_deepseek(prompt_message, role)

        if prompt:
            self.logger.info("图像提示词生成成功")
        else:
            self.logger.warning("图像提示词生成失败")

        return prompt

    def get_report_template(self, security_code: str, year: str) -> str:
        """Download the report prompt template for a security and year.

        Args:
            security_code: Security code.
            year: Year of the report.

        Returns:
            The template text, or an empty string when the download fails.
        """
        url = f"{self.BASE_REPORT_URL}/disclosure_assessment_report_prompt/{year}/{security_code}.txt"

        try:
            response = self.session.get(url, timeout=self.timeout)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            self.logger.error("获取报告模板失败: %s - %s - %s", security_code, year, e)
            return ""

        self.logger.info("报告模板获取成功: %s - %s", security_code, year)
        return response.text

    def get_report(self, security_code: str, year: str) -> str:
        """Generate a report by sending the downloaded template to DeepSeek.

        Args:
            security_code: Security code.
            year: Year of the report.

        Returns:
            The generated report, or an empty string when the template is
            unavailable or the chat call fails.
        """
        self.logger.info("开始生成报告: %s - %s", security_code, year)

        template = self.get_report_template(security_code, year)
        if not template:
            self.logger.warning("报告模板为空")
            return ""

        result = self.chat_deepseek(template, "")

        if result:
            self.logger.info("报告生成成功")
        else:
            self.logger.warning("报告生成失败")

        return result
|
||
|
||
|
||
def test():
    """Manual smoke test exercising the main AIAgent entry points.

    Each step calls a method that performs HTTP requests, so the output
    depends on the remote services being reachable. Any exception is
    printed and logged instead of being propagated.
    """
    # Retry count and timeout are configurable per instance.
    assistant = AIAgent(max_retries=3, timeout=30)

    try:
        # Report data lookup
        print("测试获取报告数据...")
        data = assistant.get_report_data("000001", "2023")
        print(f"获取到的数据: {data}")

        # Plain chat call
        print("\n测试AI聊天...")
        response = assistant.chat_deepseek("你好,介绍一下你自己", "你是一个有用的助手")
        print(f"AI回复: {response}")

        # Image prompt generation; get_image_prompt requires the security
        # code and year in addition to the painting theme.
        print("\n测试图像提示词生成...")
        prompt = assistant.get_image_prompt("春天的花园", "000001", "2023")
        print(f"图像提示词: {prompt}")

        # Report generation
        print("\n测试报告生成...")
        report = assistant.get_report("000001", "2023")
        print(f"报告内容: {report}")

    except Exception as e:
        # Top-level boundary of the smoke test: report the error and stop.
        print(f"测试过程中发生错误: {str(e)}")
        assistant.logger.error(f"主函数执行错误: {str(e)}")
|
||
|
||
|
||
|
||
# Module-level agent instance created at import time, configured with a
# single retry and a 180-second request timeout.
agent = AIAgent(max_retries=1, timeout=180)