Files
Esmart/ai.py
2025-10-01 02:03:21 +08:00

380 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import os
import time
import logging
from typing import Dict, Any, Optional
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
from jimeng import VolcEngineAPIClient
class AIAgent:
    """AI agent bundling DeepSeek chat with disclosure-assessment report helpers.

    Report data is downloaded from ``BASE_REPORT_URL`` and cached per year as
    JSON files under ``CACHE_DIR``; chat/report generation goes through the
    DeepSeek chat-completions endpoint.
    """

    # SECURITY: this key is committed to source control and must be treated as
    # leaked. It is kept only as a last-resort fallback so existing callers of
    # ``AIAgent()`` keep working; rotate it and provide DEEPSEEK_API_KEY (or the
    # constructor argument) instead, then delete this constant.
    _LEGACY_DEEPSEEK_API_KEY = "sk-c2f04b4d635b4df0aace0432f3c6d6f4"

    def __init__(self, deepseek_api_key: Optional[str] = None,
                 max_retries: int = 3, timeout: int = 30):
        """Initialise the agent.

        Args:
            deepseek_api_key: DeepSeek API key. When ``None`` (the default) the
                ``DEEPSEEK_API_KEY`` environment variable is used, falling back
                to the legacy built-in key with a warning.
            max_retries: Maximum number of retries for each HTTP request.
            timeout: Default per-request timeout in seconds.
        """
        self.timeout = timeout
        # Logging must be configured first so the key-fallback warning is emitted.
        self._setup_logging()
        if deepseek_api_key is None:
            # An empty environment variable is treated the same as an unset one.
            deepseek_api_key = os.environ.get("DEEPSEEK_API_KEY") or None
        if deepseek_api_key is None:
            self.logger.warning("未设置 DEEPSEEK_API_KEY，正在使用源码中内置的旧密钥")
            deepseek_api_key = self._LEGACY_DEEPSEEK_API_KEY
        self.deepseek_api_key = deepseek_api_key
        # Shared session that retries transient failures.
        self.session = self._create_session(max_retries)
        # Configuration constants.
        self.BASE_REPORT_URL = "https://a.cmdp.cn/basiceg/v1/json/tpldev"
        self.CACHE_DIR = "cache"

    def _setup_logging(self):
        """Configure root logging (INFO level) and create this module's logger.

        Note: ``logging.basicConfig`` is a no-op if the root logger already has
        handlers, so an application's own configuration takes precedence.
        """
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )
        self.logger = logging.getLogger(__name__)

    @staticmethod
    def _create_session(max_retries: int) -> "requests.Session":
        """Create a ``requests`` session that retries transient failures.

        Args:
            max_retries: Total number of retries allowed per request.

        Returns:
            A session with the retry adapter mounted for http:// and https://.
        """
        session = requests.Session()
        methods = ["HEAD", "GET", "POST", "PUT", "DELETE", "OPTIONS", "TRACE"]
        retry_kwargs = {
            "total": max_retries,
            # Retry on rate limiting and transient server errors.
            "status_forcelist": [429, 500, 502, 503, 504],
            "backoff_factor": 1,
        }
        try:
            # `allowed_methods` exists since urllib3 1.26; the old
            # `method_whitelist` keyword was removed in urllib3 2.0.
            retry_strategy = Retry(allowed_methods=methods, **retry_kwargs)
        except TypeError:
            # urllib3 < 1.26 only understands the legacy keyword.
            retry_strategy = Retry(method_whitelist=methods, **retry_kwargs)
        adapter = HTTPAdapter(max_retries=retry_strategy)
        session.mount("http://", adapter)
        session.mount("https://", adapter)
        return session

    def _make_request(self, method: str, url: str, **kwargs) -> Optional[Any]:
        """Send an HTTP request and decode the JSON response.

        Args:
            method: HTTP method, e.g. ``"GET"`` or ``"POST"``.
            url: Request URL.
            **kwargs: Extra arguments for ``Session.request``. ``timeout``
                defaults to ``self.timeout``.

        Returns:
            The decoded JSON (dict or list), or ``None`` if the request failed
            or the body was not valid JSON. Failures are logged, not raised.
        """
        kwargs.setdefault('timeout', self.timeout)
        try:
            response = self.session.request(method, url, **kwargs)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            self.logger.error(f"请求失败: {url} - {str(e)}")
            return None
        self.logger.info(f"请求成功: {url}")
        try:
            # Depending on the requests version and whether simplejson is
            # installed, decode errors surface as different ValueError
            # subclasses. Catching ValueError covers all of them.
            return response.json()
        except ValueError as e:
            self.logger.error(f"JSON解析失败: {url} - {str(e)}")
            return None

    def chat_deepseek(self, message: str, role: str = "", model: str = "deepseek-chat") -> str:
        """Send a single-turn chat request to DeepSeek.

        Args:
            message: User message.
            role: Optional system prompt; omitted from the request when empty.
            model: DeepSeek model name.

        Returns:
            The assistant's reply text, or ``""`` on failure or an empty result.
        """
        url = "https://api.deepseek.com/chat/completions"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.deepseek_api_key}"
        }
        messages = []
        if role:
            messages.append({"role": "system", "content": role})
        messages.append({"role": "user", "content": message})
        body = {
            "model": model,
            "messages": messages,
            "stream": False
        }
        result = self._make_request("POST", url, json=body, headers=headers)
        choices = result.get("choices") if isinstance(result, dict) else None
        if choices:
            # Guard against a null/missing content field so callers always get a str.
            return choices[0].get("message", {}).get("content") or ""
        self.logger.warning("DeepSeek API调用失败或返回空结果")
        return ""

    def generate_report_cache_data(self, year: str) -> Dict[str, Any]:
        """Download the year's report data, index it by security code and cache it.

        Args:
            year: Report year (used in the URL and the cache file name).

        Returns:
            ``{security_code: {"name", "num", "ranks": {year: rank}}, "stats": ...}``,
            or ``{}`` if any source could not be fetched (nothing is cached then).
        """
        self.logger.info(f"开始生成{year}年报告缓存数据")
        base_url = f"{self.BASE_REPORT_URL}/disclosure_assessment_report_data/{year}/"
        urls = {
            "rank_a": base_url + "Rank_A.json",
            "rank_summary": base_url + "Rank_summary.json",
            "stats": base_url + "stats.json"
        }
        # Fetch every source sequentially; abort on the first failure.
        data_sources = {}
        for key, url in urls.items():
            self.logger.info(f"获取{key}数据: {url}")
            fetched = self._make_request("GET", url)
            if not fetched:
                self.logger.error(f"获取{key}数据失败")
                return {}
            data_sources[key] = fetched

        # Keys are normalised to str: json.dump stringifies dict keys, so doing
        # it here keeps freshly generated data identical to cache-loaded data.
        data: Dict[str, Any] = {}
        for item in data_sources["rank_a"]:
            data[str(item["gsdm"])] = {
                "name": item["gsjc"],
                "num": item["num"],
                "ranks": {}
            }
        # Attach per-year ranks, only for securities listed in Rank_A.
        for item in data_sources["rank_summary"]:
            entry = data.get(str(item["gsdm"]))
            if entry is not None:
                entry["ranks"][str(item["kpnd"])] = item["kpjg"]
        data["stats"] = data_sources["stats"]

        self._save_cache(year, data)
        self.logger.info(f"{year}年报告缓存数据生成完成")
        return data

    def _save_cache(self, year: str, data: Dict[str, Any]) -> bool:
        """Write ``data`` to ``CACHE_DIR/<year>.json``.

        The file is written to a temporary path and then atomically renamed,
        so readers never see a half-written cache file.

        Args:
            year: Year used as the cache file name.
            data: JSON-serialisable data to store.

        Returns:
            ``True`` on success, ``False`` if writing failed (the error is logged).
        """
        try:
            os.makedirs(self.CACHE_DIR, exist_ok=True)
            cache_file = os.path.join(self.CACHE_DIR, f"{year}.json")
            tmp_file = cache_file + ".tmp"
            with open(tmp_file, "w", encoding="utf-8") as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
            os.replace(tmp_file, cache_file)
            self.logger.info(f"缓存已保存: {cache_file}")
            return True
        except (OSError, TypeError, ValueError) as e:
            # OSError: filesystem problems; TypeError/ValueError: unserialisable data.
            self.logger.error(f"保存缓存失败: {str(e)}")
            return False

    def _load_cache(self, year: str) -> Optional[Dict[str, Any]]:
        """Read ``CACHE_DIR/<year>.json``.

        Args:
            year: Year used as the cache file name.

        Returns:
            The cached data, or ``None`` if the file is missing or unreadable.
        """
        cache_file = os.path.join(self.CACHE_DIR, f"{year}.json")
        if not os.path.exists(cache_file):
            return None
        try:
            with open(cache_file, "r", encoding="utf-8") as f:
                data = json.load(f)
            self.logger.info(f"缓存已加载: {cache_file}")
            return data
        except (OSError, ValueError) as e:
            # ValueError covers both JSONDecodeError and UnicodeDecodeError.
            self.logger.error(f"加载缓存失败: {str(e)}")
            return None

    def get_report_data(self, security_code: str, year: str) -> Dict[str, Any]:
        """Return one security's report data plus the year's stats.

        Loads the year's cache, generating it first if it does not exist.

        Args:
            security_code: Security code.
            year: Report year.

        Returns:
            A new dict with the security's fields (empty if unknown) and a
            ``"stats"`` key.
        """
        cached_data = self._load_cache(year)
        if cached_data is None:
            cached_data = self.generate_report_cache_data(year)
        # Copy so adding "stats" does not mutate (or self-reference) the cache dict.
        result = dict(cached_data.get(security_code, {}))
        result["stats"] = cached_data.get("stats", {})
        return result

    def get_image_prompt(self, message: str, security_code: str, year: str) -> str:
        """Ask DeepSeek to write an image-generation prompt for a security.

        The system prompt is read from ``CACHE_DIR/painting.prompt``. A missing
        file raises ``FileNotFoundError``.

        Args:
            message: Painting theme.
            security_code: Security code.
            year: Report year.

        Returns:
            The generated prompt, or ``""`` if the chat call failed.
        """
        self.logger.info(f"生成图像提示词: {message}")
        with open(os.path.join(self.CACHE_DIR, "painting.prompt"), "r", encoding="utf-8") as f:
            role = f.read()
        security_data = self.get_report_data(security_code, year)
        security_name = security_data.get("name")
        rank = security_data.get("ranks", {}).get(year, "N/A")
        num = security_data.get("num", 0)
        stats = json.dumps(security_data.get("stats", {}), ensure_ascii=False)
        prompt_message = f"## 相关信息:证券名称是{security_name}{year}年排名是{rank},已经连续{num}年考评A级。\n{stats}\n\n## 绘画主题:{message}"
        prompt = self.chat_deepseek(prompt_message, role)
        if prompt:
            self.logger.info("图像提示词生成成功")
        else:
            self.logger.warning("图像提示词生成失败")
        return prompt

    def get_report_template(self, security_code: str, year: str) -> str:
        """Download the plain-text report prompt template for a security.

        Args:
            security_code: Security code.
            year: Report year.

        Returns:
            The template text, or ``""`` if the request failed (logged).
        """
        url = f"{self.BASE_REPORT_URL}/disclosure_assessment_report_prompt/{year}/{security_code}.txt"
        try:
            response = self.session.get(url, timeout=self.timeout)
            response.raise_for_status()
            self.logger.info(f"报告模板获取成功: {security_code} - {year}")
            return response.text
        except requests.exceptions.RequestException as e:
            self.logger.error(f"获取报告模板失败: {security_code} - {year} - {str(e)}")
            return ""

    def get_report(self, security_code: str, year: str) -> str:
        """Generate a report by sending the security's template to DeepSeek.

        Args:
            security_code: Security code.
            year: Report year.

        Returns:
            The generated report, or ``""`` if the template or chat call failed.
        """
        self.logger.info(f"开始生成报告: {security_code} - {year}")
        template = self.get_report_template(security_code, year)
        if not template:
            self.logger.warning("报告模板为空")
            return ""
        result = self.chat_deepseek(template, "")
        if result:
            self.logger.info("报告生成成功")
        else:
            self.logger.warning("报告生成失败")
        return result
def test():
    """Manual smoke test exercising each AIAgent feature against the live services.

    Performs real network calls (report data, DeepSeek chat, report template)
    and reads ``cache/painting.prompt``; results are printed. Any exception is
    printed and logged with its traceback rather than propagated.
    """
    # Create the assistant; retry count and timeout are configurable.
    assistant = AIAgent(max_retries=3, timeout=30)
    try:
        # Report data lookup.
        print("测试获取报告数据...")
        data = assistant.get_report_data("000001", "2023")
        print(f"获取到的数据: {data}")
        # Plain chat.
        print("\n测试AI聊天...")
        response = assistant.chat_deepseek("你好,介绍一下你自己", "你是一个有用的助手")
        print(f"AI回复: {response}")
        # Image prompt generation. get_image_prompt requires the security code
        # and year as well as the theme; omitting them raised TypeError.
        print("\n测试图像提示词生成...")
        prompt = assistant.get_image_prompt("春天的花园", "000001", "2023")
        print(f"图像提示词: {prompt}")
        # Full report generation.
        print("\n测试报告生成...")
        report = assistant.get_report("000001", "2023")
        print(f"报告内容: {report}")
    except Exception as e:
        # Top-level boundary of a manual test: report and keep the traceback.
        print(f"测试过程中发生错误: {str(e)}")
        assistant.logger.exception(f"主函数执行错误: {str(e)}")
# Module-level shared instance, built at import time with a single retry and a
# long (180 s) timeout. Presumably imported by other modules; verify callers
# before renaming or removing it.
agent = AIAgent(timeout=180, max_retries=1)